fix(observability): include MCP session ID in access logs
Some checks failed
CI / build-and-test (push) Failing after -32m50s
Some checks failed
CI / build-and-test (push) Failing after -32m50s
* Add function to extract MCP session ID from request headers and query parameters * Update access log to include MCP session ID fix(cli): simplify project lookup logic * Refactor project retrieval to prefer GUID lookup when input is a valid UUID * Introduce separate functions for fetching projects by GUID and name
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -134,9 +133,8 @@ func connectRemote(ctx context.Context) (*mcp.ClientSession, error) {
|
|||||||
transport := &mcp.StreamableClientTransport{
|
transport := &mcp.StreamableClientTransport{
|
||||||
Endpoint: endpointURL(),
|
Endpoint: endpointURL(),
|
||||||
HTTPClient: newHTTPClient(),
|
HTTPClient: newHTTPClient(),
|
||||||
|
DisableStandaloneSSE: true,
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
session, err := client.Connect(ctx, transport, nil)
|
session, err := client.Connect(ctx, transport, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("connect to AMCS server: %w", err)
|
return nil, fmt.Errorf("connect to AMCS server: %w", err)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -26,11 +25,8 @@ var sseCmd = &cobra.Command{
|
|||||||
HTTPClient: newHTTPClient(),
|
HTTPClient: newHTTPClient(),
|
||||||
}
|
}
|
||||||
|
|
||||||
connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
verboseLogf("connecting to SSE endpoint %s", sseEndpointURL())
|
verboseLogf("connecting to SSE endpoint %s", sseEndpointURL())
|
||||||
remote, err := client.Connect(connectCtx, transport, nil)
|
remote, err := client.Connect(ctx, transport, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("connect to AMCS SSE endpoint: %w", err)
|
return fmt.Errorf("connect to AMCS SSE endpoint: %w", err)
|
||||||
}
|
}
|
||||||
@@ -79,6 +75,9 @@ var sseCmd = &cobra.Command{
|
|||||||
|
|
||||||
func sseEndpointURL() string {
|
func sseEndpointURL() string {
|
||||||
base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/")
|
base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/")
|
||||||
|
if strings.HasSuffix(base, "/mcp") {
|
||||||
|
base = strings.TrimSuffix(base, "/mcp")
|
||||||
|
}
|
||||||
if strings.HasSuffix(base, "/sse") {
|
if strings.HasSuffix(base, "/sse") {
|
||||||
return base
|
return base
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,9 +78,10 @@ func AccessLog(log *slog.Logger) func(http.Handler) http.Handler {
|
|||||||
slog.Int("status", recorder.status),
|
slog.Int("status", recorder.status),
|
||||||
slog.Duration("duration", time.Since(started)),
|
slog.Duration("duration", time.Since(started)),
|
||||||
slog.String("remote_addr", requestip.FromRequest(r)),
|
slog.String("remote_addr", requestip.FromRequest(r)),
|
||||||
|
slog.String("mcp_session_id", mcpSessionIDFromRequest(r)),
|
||||||
}
|
}
|
||||||
if tool, _ := r.Context().Value(mcpToolContextKey).(string); strings.TrimSpace(tool) != "" {
|
if tool, _ := r.Context().Value(mcpToolContextKey).(string); strings.TrimSpace(tool) != "" {
|
||||||
attrs = append(attrs, slog.String("tool", tool))
|
attrs = append(attrs, slog.String("tool", tool), slog.String("tool_call", tool))
|
||||||
}
|
}
|
||||||
log.Info("http request", attrs...)
|
log.Info("http request", attrs...)
|
||||||
})
|
})
|
||||||
@@ -150,6 +151,22 @@ func mcpToolFromRequest(r *http.Request) string {
|
|||||||
return msg.toolName()
|
return msg.toolName()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mcpSessionIDFromRequest(r *http.Request) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(r.Header.Get("MCP-Session-Id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
// Some clients/proxies may propagate the session in query params.
|
||||||
|
for _, key := range []string{"session_id", "sessionId", "mcp_session_id"} {
|
||||||
|
if v := strings.TrimSpace(r.URL.Query().Get(key)); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
type rpcEnvelope struct {
|
type rpcEnvelope struct {
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
Params struct {
|
Params struct {
|
||||||
|
|||||||
@@ -113,4 +113,46 @@ func TestAccessLogIncludesMCPToolName(t *testing.T) {
|
|||||||
if !strings.Contains(buf.String(), "tool=list_projects") {
|
if !strings.Contains(buf.String(), "tool=list_projects") {
|
||||||
t.Fatalf("log output = %q, want tool=list_projects", buf.String())
|
t.Fatalf("log output = %q, want tool=list_projects", buf.String())
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(buf.String(), "tool_call=list_projects") {
|
||||||
|
t.Fatalf("log output = %q, want tool_call=list_projects", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogIncludesMCPSessionIDHeader(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sse", nil)
|
||||||
|
req.Header.Set("MCP-Session-Id", "sess-123")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "mcp_session_id=sess-123") {
|
||||||
|
t.Fatalf("log output = %q, want mcp_session_id=sess-123", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogIncludesMCPSessionIDQueryParam(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sse?session_id=sess-q-1", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "mcp_session_id=sess-q-1") {
|
||||||
|
t.Fatalf("log output = %q, want mcp_session_id=sess-q-1", buf.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,21 +26,42 @@ func (db *DB) CreateProject(ctx context.Context, name, description string) (thou
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) {
|
func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) {
|
||||||
var row pgx.Row
|
lookup := strings.TrimSpace(nameOrID)
|
||||||
if parsedID, err := uuid.Parse(strings.TrimSpace(nameOrID)); err == nil {
|
|
||||||
row = db.pool.QueryRow(ctx, `
|
// Prefer guid lookup when input parses as UUID, but fall back to name lookup
|
||||||
|
// so UUID-shaped project names can still be resolved by name.
|
||||||
|
if parsedID, err := uuid.Parse(lookup); err == nil {
|
||||||
|
project, queryErr := db.getProjectByGUID(ctx, parsedID)
|
||||||
|
if queryErr == nil {
|
||||||
|
return project, nil
|
||||||
|
}
|
||||||
|
if queryErr != pgx.ErrNoRows {
|
||||||
|
return thoughttypes.Project{}, queryErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.getProjectByName(ctx, lookup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) getProjectByGUID(ctx context.Context, id uuid.UUID) (thoughttypes.Project, error) {
|
||||||
|
row := db.pool.QueryRow(ctx, `
|
||||||
select guid, name, description, created_at, last_active_at
|
select guid, name, description, created_at, last_active_at
|
||||||
from projects
|
from projects
|
||||||
where guid = $1
|
where guid = $1
|
||||||
`, parsedID)
|
`, id)
|
||||||
} else {
|
return scanProject(row)
|
||||||
row = db.pool.QueryRow(ctx, `
|
}
|
||||||
|
|
||||||
|
func (db *DB) getProjectByName(ctx context.Context, name string) (thoughttypes.Project, error) {
|
||||||
|
row := db.pool.QueryRow(ctx, `
|
||||||
select guid, name, description, created_at, last_active_at
|
select guid, name, description, created_at, last_active_at
|
||||||
from projects
|
from projects
|
||||||
where name = $1
|
where name = $1
|
||||||
`, strings.TrimSpace(nameOrID))
|
`, name)
|
||||||
|
return scanProject(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func scanProject(row pgx.Row) (thoughttypes.Project, error) {
|
||||||
var project thoughttypes.Project
|
var project thoughttypes.Project
|
||||||
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
||||||
if err == pgx.ErrNoRows {
|
if err == pgx.ErrNoRows {
|
||||||
|
|||||||
Reference in New Issue
Block a user