package observability import ( "bytes" "encoding/json" "io" "log/slog" "net/http" "net/http/httptest" "strings" "testing" "time" ) func TestRequestIDSetsHeaderAndContext(t *testing.T) { handler := RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := RequestIDFromContext(r.Context()); got == "" { t.Fatal("RequestIDFromContext() = empty, want non-empty") } w.WriteHeader(http.StatusNoContent) })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("X-Request-Id") == "" { t.Fatal("X-Request-Id header = empty, want non-empty") } } func TestTimeoutAddsContextDeadline(t *testing.T) { handler := Timeout(50 * time.Millisecond)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, ok := r.Context().Deadline(); !ok { t.Fatal("context deadline missing") } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } func TestRecoverHandlesPanic(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) handler := Recover(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic("boom") })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusInternalServerError { t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError) } } func TestAccessLogUsesForwardedClientIP(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, nil)) handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) req.RemoteAddr = "10.0.0.10:1234" req.Header.Set("X-Real-IP", "203.0.113.7") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } if !strings.Contains(buf.String(), "remote_addr=203.0.113.7") { t.Fatalf("log output = %q, want remote_addr=203.0.113.7", buf.String()) } } func TestAccessLogIncludesMCPToolName(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, nil)) handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) payload := map[string]any{ "jsonrpc": "2.0", "id": "1", "method": "tools/call", "params": map[string]any{ "name": "list_projects", "arguments": map[string]any{}, }, } body, err := json.Marshal(payload) if err != nil { t.Fatalf("json.Marshal() error = %v", err) } req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body)) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } if !strings.Contains(buf.String(), "tool=list_projects") { t.Fatalf("log output = %q, want tool=list_projects", buf.String()) } if !strings.Contains(buf.String(), "tool_call=list_projects") { t.Fatalf("log output = %q, want tool_call=list_projects", buf.String()) } } func TestAccessLogIncludesMCPSessionIDHeader(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, nil)) handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) req := httptest.NewRequest(http.MethodGet, "/sse", nil) req.Header.Set("MCP-Session-Id", "sess-123") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } if !strings.Contains(buf.String(), "mcp_session_id=sess-123") { t.Fatalf("log output = %q, want mcp_session_id=sess-123", buf.String()) } } func TestAccessLogIncludesMCPSessionIDQueryParam(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, nil)) handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) req := httptest.NewRequest(http.MethodGet, "/sse?session_id=sess-q-1", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } if !strings.Contains(buf.String(), "mcp_session_id=sess-q-1") { t.Fatalf("log output = %q, want mcp_session_id=sess-q-1", buf.String()) } }