// Package auth provides API-key authentication middleware for HTTP handlers.
package auth

import (
	"context"
	"log/slog"
	"net/http"
	"strings"

	"git.warky.dev/wdevs/amcs/internal/config"
)

// contextKey is an unexported type for context keys set by this package, so
// they cannot collide with keys defined in other packages.
type contextKey string

const (
	// keyIDContextKey is the request-context key under which Middleware stores
	// the ID of the authenticated API key.
	keyIDContextKey contextKey = "auth.key_id"

	// defaultHeaderName is the API-key header consulted when cfg.HeaderName is empty.
	defaultHeaderName = "x-brain-key"
)

// Middleware returns HTTP middleware that rejects requests lacking a valid API key.
//
// The token is taken from the first of these sources that yields a non-empty value:
//   - the header named by cfg.HeaderName (default "x-brain-key");
//   - an "Authorization: Bearer <token>" header (scheme matched case-insensitively);
//   - the query parameter named by cfg.QueryParam, consulted only when
//     cfg.AllowQueryParam is true and cfg.QueryParam is non-empty.
//
// A request with no token gets a 401 response with body "missing API key".
// A request whose token is not recognized by keyring.Lookup gets a 401 response
// with body "invalid API key", and a warning carrying the remote address is logged.
// On success, the key ID returned by keyring.Lookup is stored in the request
// context, where KeyIDFromContext can read it, and the request is passed to next.
//
// If log is nil, slog.Default() is used.
func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler {
	headerName := cfg.HeaderName
	if headerName == "" {
		headerName = defaultHeaderName
	}
	if log == nil {
		// Avoid a nil-pointer panic on the first failed authentication.
		log = slog.Default()
	}

	// Resolve the query parameter once. An empty name would otherwise make
	// Query().Get("") accept a token passed as "?=<token>".
	queryParam := ""
	if cfg.AllowQueryParam {
		queryParam = cfg.QueryParam
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := extractToken(r, headerName)
			if token == "" && queryParam != "" {
				// NOTE(review): tokens in URLs may end up in access logs, proxy
				// logs, or Referer headers; prefer header-based auth where possible.
				token = strings.TrimSpace(r.URL.Query().Get(queryParam))
			}
			if token == "" {
				http.Error(w, "missing API key", http.StatusUnauthorized)
				return
			}

			keyID, ok := keyring.Lookup(token)
			if !ok {
				// Log only the remote address, never the presented token.
				log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
				http.Error(w, "invalid API key", http.StatusUnauthorized)
				return
			}

			ctx := context.WithValue(r.Context(), keyIDContextKey, keyID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// extractToken returns the API key carried in the request headers, or "" if none.
//
// The header named headerName takes precedence, with surrounding whitespace trimmed.
// Otherwise the credentials of an "Authorization: <scheme> <credentials>" header
// are returned when the scheme equals "Bearer" case-insensitively. Any other
// Authorization scheme, or a header without a space-separated credential, yields "".
func extractToken(r *http.Request, headerName string) string {
	if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
		return token
	}

	authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
	scheme, credentials, ok := strings.Cut(authHeader, " ")
	if !ok || !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	// Trim again so that extra spaces between the scheme and the token are tolerated.
	return strings.TrimSpace(credentials)
}

// KeyIDFromContext returns the API key ID that Middleware stored in ctx.
// The boolean is false if no key ID is present or the stored value is not a string.
func KeyIDFromContext(ctx context.Context) (string, bool) {
	value, ok := ctx.Value(keyIDContextKey).(string)
	return value, ok
}