From 512b16f8fe1b01cb62e91ffb11c8bd2644f77591 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 21 Apr 2026 22:35:42 +0200 Subject: [PATCH] feat(observability): add MCP tool name logging in access log * Include tool name from request in access log entries * Update user agent header in HTTP requests * Add tests for MCP tool name logging --- cmd/amcs-cli/cmd/root.go | 5 +++ cmd/amcs-cli/cmd/root_test.go | 4 ++ internal/observability/http.go | 66 +++++++++++++++++++++++++++- internal/observability/http_test.go | 34 ++++++++++++++ internal/requestip/requestip.go | 10 ++--- internal/requestip/requestip_test.go | 12 +---- 6 files changed, 111 insertions(+), 20 deletions(-) diff --git a/cmd/amcs-cli/cmd/root.go b/cmd/amcs-cli/cmd/root.go index e1a69c6..b658b6f 100644 --- a/cmd/amcs-cli/cmd/root.go +++ b/cmd/amcs-cli/cmd/root.go @@ -21,6 +21,8 @@ var ( cfg Config ) +const cliUserAgent = "amcs-cli/0.0.1" + var rootCmd = &cobra.Command{ Use: "amcs-cli", Short: "CLI for connecting to a remote AMCS MCP server", @@ -114,6 +116,9 @@ func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { base = http.DefaultTransport } clone := req.Clone(req.Context()) + if strings.TrimSpace(clone.Header.Get("User-Agent")) == "" { + clone.Header.Set("User-Agent", cliUserAgent) + } if strings.TrimSpace(t.token) != "" { clone.Header.Set("Authorization", "Bearer "+t.token) } diff --git a/cmd/amcs-cli/cmd/root_test.go b/cmd/amcs-cli/cmd/root_test.go index 2bb737e..16ad0d4 100644 --- a/cmd/amcs-cli/cmd/root_test.go +++ b/cmd/amcs-cli/cmd/root_test.go @@ -8,11 +8,15 @@ import ( func TestBearerTransportFormatsBearerToken(t *testing.T) { const want = "Bearer X" + const wantUA = "amcs-cli/0.0.1" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Authorization"); got != want { t.Fatalf("Authorization header = %q, want %q", got, want) } + if got := r.Header.Get("User-Agent"); got != wantUA { + t.Fatalf("User-Agent header = %q, want %q", got, wantUA) + } w.WriteHeader(http.StatusNoContent) })) defer ts.Close() diff --git a/internal/observability/http.go b/internal/observability/http.go index 09aaef0..e442e9a 100644 --- a/internal/observability/http.go +++ b/internal/observability/http.go @@ -1,10 +1,14 @@ package observability import ( + "bytes" "context" + "encoding/json" + "io" "log/slog" "net/http" "runtime/debug" + "strings" "time" "github.com/google/uuid" @@ -15,6 +19,7 @@ import ( type contextKey string const requestIDContextKey contextKey = "request_id" +const mcpToolContextKey contextKey = "mcp_tool" func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler { for i := len(middlewares) - 1; i >= 0; i-- { @@ -58,18 +63,26 @@ func Recover(log *slog.Logger) func(http.Handler) http.Handler { func AccessLog(log *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tool := mcpToolFromRequest(r); tool != "" { + r = r.WithContext(context.WithValue(r.Context(), mcpToolContextKey, tool)) + } + recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK} started := time.Now() next.ServeHTTP(recorder, r) - log.Info("http request", + attrs := []any{ slog.String("request_id", RequestIDFromContext(r.Context())), slog.String("method", r.Method), slog.String("path", r.URL.Path), slog.Int("status", recorder.status), slog.Duration("duration", time.Since(started)), slog.String("remote_addr", requestip.FromRequest(r)), - ) + } + if tool, _ := r.Context().Value(mcpToolContextKey).(string); strings.TrimSpace(tool) != "" { + attrs = append(attrs, slog.String("tool", tool)) + } + log.Info("http request", attrs...) }) } } @@ -101,3 +114,52 @@ func (s *statusRecorder) WriteHeader(statusCode int) { s.status = statusCode s.ResponseWriter.WriteHeader(statusCode) } + +func mcpToolFromRequest(r *http.Request) string { + if r == nil || r.Method != http.MethodPost || !strings.HasPrefix(r.URL.Path, "/mcp") || r.Body == nil { + return "" + } + + raw, err := io.ReadAll(r.Body) + if err != nil { + return "" + } + r.Body = io.NopCloser(bytes.NewReader(raw)) + if len(raw) == 0 { + return "" + } + + // Support both single and batch JSON-RPC payloads. + if strings.HasPrefix(strings.TrimSpace(string(raw)), "[") { + var batch []rpcEnvelope + if err := json.Unmarshal(raw, &batch); err != nil { + return "" + } + for _, msg := range batch { + if tool := msg.toolName(); tool != "" { + return tool + } + } + return "" + } + + var msg rpcEnvelope + if err := json.Unmarshal(raw, &msg); err != nil { + return "" + } + return msg.toolName() +} + +type rpcEnvelope struct { + Method string `json:"method"` + Params struct { + Name string `json:"name"` + } `json:"params"` +} + +func (m rpcEnvelope) toolName() string { + if m.Method != "tools/call" { + return "" + } + return strings.TrimSpace(m.Params.Name) +} diff --git a/internal/observability/http_test.go b/internal/observability/http_test.go index fe69358..cf97b4f 100644 --- a/internal/observability/http_test.go +++ b/internal/observability/http_test.go @@ -2,6 +2,7 @@ package observability import ( "bytes" + "encoding/json" "io" "log/slog" "net/http" @@ -80,3 +81,36 @@ func TestAccessLogUsesForwardedClientIP(t *testing.T) { t.Fatalf("log output = %q, want remote_addr=203.0.113.7", buf.String()) } } + +func TestAccessLogIncludesMCPToolName(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + payload := map[string]any{ + "jsonrpc": "2.0", + "id": "1", + "method": "tools/call", + "params": map[string]any{ + "name": "list_projects", + "arguments": map[string]any{}, + }, + } + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + if !strings.Contains(buf.String(), "tool=list_projects") { + t.Fatalf("log output = %q, want tool=list_projects", buf.String()) + } +} diff --git a/internal/requestip/requestip.go b/internal/requestip/requestip.go index ea5ffee..e639cf1 100644 --- a/internal/requestip/requestip.go +++ b/internal/requestip/requestip.go @@ -11,10 +11,9 @@ import ( // // Header precedence: // 1) X-Real-IP -// 2) X-Forwarded-Host -// 3) X-Forwarded-For (first value) -// 4) Forwarded (for=...) -// 5) RemoteAddr (host part) +// 2) X-Forwarded-For (first value) +// 3) Forwarded (for=...) +// 4) RemoteAddr (host part) func FromRequest(r *http.Request) string { if r == nil { return "" @@ -23,9 +22,6 @@ func FromRequest(r *http.Request) string { if v := firstAddressToken(r.Header.Get("X-Real-IP")); v != "" { return stripPort(v) } - if v := firstAddressToken(r.Header.Get("X-Forwarded-Host")); v != "" { - return stripPort(v) - } if v := firstAddressToken(r.Header.Get("X-Forwarded-For")); v != "" { return stripPort(v) } diff --git a/internal/requestip/requestip_test.go b/internal/requestip/requestip_test.go index 09a15a7..0756ed2 100644 --- a/internal/requestip/requestip_test.go +++ b/internal/requestip/requestip_test.go @@ -9,7 +9,7 @@ import ( func TestFromRequestPrefersXRealIP(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.RemoteAddr = "10.0.0.10:5555" - req.Header.Set("X-Forwarded-Host", "proxy.example.com") + req.Header.Set("X-Forwarded-For", "198.51.100.1") req.Header.Set("X-Real-IP", "203.0.113.10") if got := FromRequest(req); got != "203.0.113.10" { @@ -17,16 +17,6 @@ func TestFromRequestPrefersXRealIP(t *testing.T) { } } -func TestFromRequestUsesXForwardedHostWhenRealIPMissing(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "10.0.0.10:5555" - req.Header.Set("X-Forwarded-Host", "203.0.113.22") - - if got := FromRequest(req); got != "203.0.113.22" { - t.Fatalf("FromRequest() = %q, want %q", got, "203.0.113.22") - } -} - func TestFromRequestUsesXForwardedForFirstValue(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.RemoteAddr = "10.0.0.10:5555"