From 3dfed9c986a4354be3da35892036bdff6b38e8d3 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 21 Apr 2026 23:04:46 +0200 Subject: [PATCH] fix(observability): include MCP session ID in access logs * Add function to extract MCP session ID from request headers and query parameters * Update access log to include MCP session ID fix(cli): simplify project lookup logic * Refactor project retrieval to prefer GUID lookup when input is a valid UUID * Introduce separate functions for fetching projects by GUID and name --- cmd/amcs-cli/cmd/root.go | 8 ++--- cmd/amcs-cli/cmd/sse.go | 9 +++--- internal/observability/http.go | 19 +++++++++++- internal/observability/http_test.go | 42 ++++++++++++++++++++++++++ internal/store/projects.go | 47 +++++++++++++++++++++-------- 5 files changed, 101 insertions(+), 24 deletions(-) diff --git a/cmd/amcs-cli/cmd/root.go b/cmd/amcs-cli/cmd/root.go index b658b6f..d0c9336 100644 --- a/cmd/amcs-cli/cmd/root.go +++ b/cmd/amcs-cli/cmd/root.go @@ -6,7 +6,6 @@ import ( "net/http" "os" "strings" - "time" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/spf13/cobra" @@ -132,11 +131,10 @@ func connectRemote(ctx context.Context) (*mcp.ClientSession, error) { verboseLogf("connecting to %s", endpointURL()) client := mcp.NewClient(&mcp.Implementation{Name: "amcs-cli", Version: "0.0.1"}, nil) transport := &mcp.StreamableClientTransport{ - Endpoint: endpointURL(), - HTTPClient: newHTTPClient(), + Endpoint: endpointURL(), + HTTPClient: newHTTPClient(), + DisableStandaloneSSE: true, } - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() session, err := client.Connect(ctx, transport, nil) if err != nil { return nil, fmt.Errorf("connect to AMCS server: %w", err) diff --git a/cmd/amcs-cli/cmd/sse.go b/cmd/amcs-cli/cmd/sse.go index a5a7b47..2d08341 100644 --- a/cmd/amcs-cli/cmd/sse.go +++ b/cmd/amcs-cli/cmd/sse.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "time" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/spf13/cobra" @@ -26,11 +25,8 @@ var sseCmd = &cobra.Command{ HTTPClient: newHTTPClient(), } - connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - verboseLogf("connecting to SSE endpoint %s", sseEndpointURL()) - remote, err := client.Connect(connectCtx, transport, nil) + remote, err := client.Connect(ctx, transport, nil) if err != nil { return fmt.Errorf("connect to AMCS SSE endpoint: %w", err) } @@ -79,6 +75,9 @@ var sseCmd = &cobra.Command{ func sseEndpointURL() string { base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/") + if strings.HasSuffix(base, "/mcp") { + base = strings.TrimSuffix(base, "/mcp") + } if strings.HasSuffix(base, "/sse") { return base } diff --git a/internal/observability/http.go b/internal/observability/http.go index e442e9a..064392f 100644 --- a/internal/observability/http.go +++ b/internal/observability/http.go @@ -78,9 +78,10 @@ func AccessLog(log *slog.Logger) func(http.Handler) http.Handler { slog.Int("status", recorder.status), slog.Duration("duration", time.Since(started)), slog.String("remote_addr", requestip.FromRequest(r)), + slog.String("mcp_session_id", mcpSessionIDFromRequest(r)), } if tool, _ := r.Context().Value(mcpToolContextKey).(string); strings.TrimSpace(tool) != "" { - attrs = append(attrs, slog.String("tool", tool)) + attrs = append(attrs, slog.String("tool", tool), slog.String("tool_call", tool)) } log.Info("http request", attrs...) }) @@ -150,6 +151,22 @@ func mcpToolFromRequest(r *http.Request) string { return msg.toolName() } +func mcpSessionIDFromRequest(r *http.Request) string { + if r == nil { + return "" + } + if v := strings.TrimSpace(r.Header.Get("MCP-Session-Id")); v != "" { + return v + } + // Some clients/proxies may propagate the session in query params. + for _, key := range []string{"session_id", "sessionId", "mcp_session_id"} { + if v := strings.TrimSpace(r.URL.Query().Get(key)); v != "" { + return v + } + } + return "" +} + type rpcEnvelope struct { Method string `json:"method"` Params struct { diff --git a/internal/observability/http_test.go b/internal/observability/http_test.go index cf97b4f..6e4d0ea 100644 --- a/internal/observability/http_test.go +++ b/internal/observability/http_test.go @@ -113,4 +113,46 @@ func TestAccessLogIncludesMCPToolName(t *testing.T) { if !strings.Contains(buf.String(), "tool=list_projects") { t.Fatalf("log output = %q, want tool=list_projects", buf.String()) } + if !strings.Contains(buf.String(), "tool_call=list_projects") { + t.Fatalf("log output = %q, want tool_call=list_projects", buf.String()) + } +} + +func TestAccessLogIncludesMCPSessionIDHeader(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodGet, "/sse", nil) + req.Header.Set("MCP-Session-Id", "sess-123") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + if !strings.Contains(buf.String(), "mcp_session_id=sess-123") { + t.Fatalf("log output = %q, want mcp_session_id=sess-123", buf.String()) + } +} + +func TestAccessLogIncludesMCPSessionIDQueryParam(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodGet, "/sse?session_id=sess-q-1", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + if !strings.Contains(buf.String(), "mcp_session_id=sess-q-1") { + t.Fatalf("log output = %q, want mcp_session_id=sess-q-1", buf.String()) + } } diff --git a/internal/store/projects.go b/internal/store/projects.go index d3a707f..aa1834a 100644 --- a/internal/store/projects.go +++ b/internal/store/projects.go @@ -26,21 +26,42 @@ func (db *DB) CreateProject(ctx context.Context, name, description string) (thou } func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) { - var row pgx.Row - if parsedID, err := uuid.Parse(strings.TrimSpace(nameOrID)); err == nil { - row = db.pool.QueryRow(ctx, ` - select guid, name, description, created_at, last_active_at - from projects - where guid = $1 - `, parsedID) - } else { - row = db.pool.QueryRow(ctx, ` - select guid, name, description, created_at, last_active_at - from projects - where name = $1 - `, strings.TrimSpace(nameOrID)) + lookup := strings.TrimSpace(nameOrID) + + // Prefer guid lookup when input parses as UUID, but fall back to name lookup + // so UUID-shaped project names can still be resolved by name. + if parsedID, err := uuid.Parse(lookup); err == nil { + project, queryErr := db.getProjectByGUID(ctx, parsedID) + if queryErr == nil { + return project, nil + } + if queryErr != pgx.ErrNoRows { + return thoughttypes.Project{}, queryErr + } } + return db.getProjectByName(ctx, lookup) +} + +func (db *DB) getProjectByGUID(ctx context.Context, id uuid.UUID) (thoughttypes.Project, error) { + row := db.pool.QueryRow(ctx, ` + select guid, name, description, created_at, last_active_at + from projects + where guid = $1 + `, id) + return scanProject(row) +} + +func (db *DB) getProjectByName(ctx context.Context, name string) (thoughttypes.Project, error) { + row := db.pool.QueryRow(ctx, ` + select guid, name, description, created_at, last_active_at + from projects + where name = $1 + `, name) + return scanProject(row) +} + +func scanProject(row pgx.Row) (thoughttypes.Project, error) { var project thoughttypes.Project if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil { if err == pgx.ErrNoRows {