diff --git a/README.md b/README.md index 9a86443..5298361 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte Config is YAML-driven. Copy `configs/config.example.yaml` and set: - `database.url` — Postgres connection string -- `auth.keys` — API keys for MCP endpoint access +- `auth.keys` — API keys for MCP endpoint access via `x-brain-key` or `Authorization: Bearer ` - `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy - `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server diff --git a/internal/auth/keyring_test.go b/internal/auth/keyring_test.go index 95fc074..6fda6b6 100644 --- a/internal/auth/keyring_test.go +++ b/internal/auth/keyring_test.go @@ -57,6 +57,58 @@ func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) { } } +func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) { + keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + keyID, ok := KeyIDFromContext(r.Context()) + if !ok || keyID != "client-a" { + t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("Authorization", "Bearer secret") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) { + keyring, err := NewKeyring([]config.APIKey{ + {ID: "client-a", Value: "secret"}, + {ID: "client-b", Value: "other-secret"}, + }) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + keyID, ok := KeyIDFromContext(r.Context()) + if !ok || keyID != "client-a" { + t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("x-brain-key", "secret") + req.Header.Set("Authorization", "Bearer other-secret") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) { keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) if err != nil { diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index a99da09..d12abed 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -21,7 +21,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func( return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := strings.TrimSpace(r.Header.Get(headerName)) + token := extractToken(r, headerName) if token == "" && cfg.AllowQueryParam { token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)) } @@ -43,6 +43,21 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func( } } +func extractToken(r *http.Request, headerName string) string { + token := strings.TrimSpace(r.Header.Get(headerName)) + if token != "" { + return token + } + + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + scheme, credentials, ok := strings.Cut(authHeader, " ") + if !ok || !strings.EqualFold(scheme, "Bearer") { + return "" + } + + return strings.TrimSpace(credentials) +} + func KeyIDFromContext(ctx context.Context) (string, bool) { value, ok := ctx.Value(keyIDContextKey).(string) return value, ok