feat(auth): enhance middleware to support Bearer token auth
* Added support for extracting Bearer tokens from Authorization header. * Updated middleware to prefer explicit header over Bearer token. * Improved test coverage for authentication scenarios.
This commit is contained in:
@@ -47,7 +47,7 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte
|
|||||||
Config is YAML-driven. Copy `configs/config.example.yaml` and set:
|
Config is YAML-driven. Copy `configs/config.example.yaml` and set:
|
||||||
|
|
||||||
- `database.url` — Postgres connection string
|
- `database.url` — Postgres connection string
|
||||||
- `auth.keys` — API keys for MCP endpoint access
|
- `auth.keys` — API keys for MCP endpoint access via `x-brain-key` or `Authorization: Bearer <key>`
|
||||||
- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy
|
- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy
|
||||||
- `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server
|
- `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,58 @@ func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) {
|
||||||
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
keyID, ok := KeyIDFromContext(r.Context())
|
||||||
|
if !ok || keyID != "client-a" {
|
||||||
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer secret")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) {
|
||||||
|
keyring, err := NewKeyring([]config.APIKey{
|
||||||
|
{ID: "client-a", Value: "secret"},
|
||||||
|
{ID: "client-b", Value: "other-secret"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
keyID, ok := KeyIDFromContext(r.Context())
|
||||||
|
if !ok || keyID != "client-a" {
|
||||||
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
req.Header.Set("x-brain-key", "secret")
|
||||||
|
req.Header.Set("Authorization", "Bearer other-secret")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) {
|
func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) {
|
||||||
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(
|
|||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
token := strings.TrimSpace(r.Header.Get(headerName))
|
token := extractToken(r, headerName)
|
||||||
if token == "" && cfg.AllowQueryParam {
|
if token == "" && cfg.AllowQueryParam {
|
||||||
token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam))
|
token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam))
|
||||||
}
|
}
|
||||||
@@ -43,6 +43,21 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractToken(r *http.Request, headerName string) string {
|
||||||
|
token := strings.TrimSpace(r.Header.Get(headerName))
|
||||||
|
if token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||||
|
scheme, credentials, ok := strings.Cut(authHeader, " ")
|
||||||
|
if !ok || !strings.EqualFold(scheme, "Bearer") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(credentials)
|
||||||
|
}
|
||||||
|
|
||||||
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
||||||
value, ok := ctx.Value(keyIDContextKey).(string)
|
value, ok := ctx.Value(keyIDContextKey).(string)
|
||||||
return value, ok
|
return value, ok
|
||||||
|
|||||||
Reference in New Issue
Block a user