package auth import ( "context" "encoding/base64" "log/slog" "net/http" "strings" "time" "git.warky.dev/wdevs/amcs/internal/config" ) type contextKey string const keyIDContextKey contextKey = "auth.key_id" func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, tracker *AccessTracker, log *slog.Logger) func(http.Handler) http.Handler { headerName := cfg.HeaderName if headerName == "" { headerName = "x-brain-key" } recordAccess := func(r *http.Request, keyID string) { if tracker != nil { tracker.Record(keyID, r.URL.Path, r.RemoteAddr, r.UserAgent(), time.Now()) } } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 1. Custom header → keyring only. if keyring != nil { if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" { keyID, ok := keyring.Lookup(token) if !ok { log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr)) http.Error(w, "invalid API key", http.StatusUnauthorized) return } recordAccess(r, keyID) next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) return } } // 2. Bearer token → tokenStore (OAuth), then keyring (API key). if bearer := extractBearer(r); bearer != "" { if tokenStore != nil { if keyID, ok := tokenStore.Lookup(bearer); ok { recordAccess(r, keyID) next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) return } } if keyring != nil { if keyID, ok := keyring.Lookup(bearer); ok { recordAccess(r, keyID) next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) return } } log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr)) http.Error(w, "invalid token or API key", http.StatusUnauthorized) return } // 3. HTTP Basic → oauthRegistry (direct client credentials). if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" { if oauthRegistry == nil { http.Error(w, "authentication is not configured", http.StatusUnauthorized) return } keyID, ok := oauthRegistry.Lookup(clientID, clientSecret) if !ok { log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr)) http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized) return } recordAccess(r, keyID) next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) return } // 4. Query param → keyring. if keyring != nil && cfg.AllowQueryParam { if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" { keyID, ok := keyring.Lookup(token) if !ok { log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr)) http.Error(w, "invalid API key", http.StatusUnauthorized) return } recordAccess(r, keyID) next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) return } } http.Error(w, "authentication required", http.StatusUnauthorized) }) } } func extractBearer(r *http.Request) string { authHeader := strings.TrimSpace(r.Header.Get("Authorization")) scheme, credentials, ok := strings.Cut(authHeader, " ") if !ok || !strings.EqualFold(scheme, "Bearer") { return "" } return strings.TrimSpace(credentials) } func extractOAuthClientCredentials(r *http.Request) (string, string) { authHeader := strings.TrimSpace(r.Header.Get("Authorization")) scheme, credentials, ok := strings.Cut(authHeader, " ") if ok && strings.EqualFold(scheme, "Basic") { decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(credentials)) if err == nil { clientID, clientSecret, found := strings.Cut(string(decoded), ":") if found { return strings.TrimSpace(clientID), strings.TrimSpace(clientSecret) } } } return "", "" } func KeyIDFromContext(ctx context.Context) (string, bool) { value, ok := ctx.Value(keyIDContextKey).(string) return value, ok }