package auth import ( "bytes" "encoding/json" "io" "log/slog" "net/http" "net/http/httptest" "testing" "git.warky.dev/wdevs/amcs/internal/config" "git.warky.dev/wdevs/amcs/internal/observability" ) func testLogger() *slog.Logger { return slog.New(slog.NewTextHandler(io.Discard, nil)) } func TestNewKeyringAndLookup(t *testing.T) { _, err := NewKeyring(nil) if err == nil { t.Fatal("NewKeyring(nil) error = nil, want error") } keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } if got, ok := keyring.Lookup("secret"); !ok || got != "client-a" { t.Fatalf("Lookup(secret) = (%q, %v), want (client-a, true)", got, ok) } if _, ok := keyring.Lookup("wrong"); ok { t.Fatal("Lookup(wrong) = true, want false") } } func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "client-a" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) req.Header.Set("x-brain-key", "secret") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "client-a" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) req.Header.Set("Authorization", "Bearer secret") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{ {ID: "client-a", Value: "secret"}, {ID: "client-b", Value: "other-secret"}, }) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "client-a" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) req.Header.Set("x-brain-key", "secret") req.Header.Set("Authorization", "Bearer other-secret") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } handler := Middleware(config.AuthConfig{ HeaderName: "x-brain-key", QueryParam: "key", AllowQueryParam: true, }, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) req := httptest.NewRequest(http.MethodGet, "/mcp?key=secret", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } } func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next handler should not be called") })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Fatalf("missing key status = %d, want %d", rec.Code, http.StatusUnauthorized) } req = httptest.NewRequest(http.MethodGet, "/mcp", nil) req.Header.Set("x-brain-key", "wrong") rec = httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized) } } func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } tracker := NewAccessTracker() handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) req.RemoteAddr = "10.0.0.5:2222" req.Header.Set("x-brain-key", "secret") req.Header.Set("X-Real-IP", "203.0.113.99") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } snap := tracker.Snapshot() if len(snap) != 1 { t.Fatalf("len(snapshot) = %d, want 1", len(snap)) } if snap[0].RemoteAddr != "203.0.113.99" { t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99") } } func TestMiddlewareRecordsMCPToolUsage(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { t.Fatalf("NewKeyring() error = %v", err) } tracker := NewAccessTracker() logger := testLogger() authenticated := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) handler := observability.AccessLog(logger)(authenticated) payload := map[string]any{ "jsonrpc": "2.0", "id": "1", "method": "tools/call", "params": map[string]any{ "name": "list_projects", }, } body, err := json.Marshal(payload) if err != nil { t.Fatalf("json.Marshal() error = %v", err) } req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body)) req.Header.Set("x-brain-key", "secret") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } metrics := tracker.Metrics(10) if metrics.UniqueTools != 1 { t.Fatalf("UniqueTools = %d, want 1", metrics.UniqueTools) } if len(metrics.TopTools) != 1 { t.Fatalf("len(TopTools) = %d, want 1", len(metrics.TopTools)) } if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 1 { t.Fatalf("TopTools[0] = %+v, want list_projects with count 1", metrics.TopTools[0]) } }