Files
amcs/internal/auth/middleware.go
T
Hein a993859c62
CI / build-and-test (push) Has been cancelled
feat(db): add oauth_clients table for dynamic client registration
* Introduced oauth_clients table with fields for client_id, client_name, redirect_uris, and created_at.
* Updated agent_persona_parts, agent_persona_skills, agent_persona_guardrails, agent_persona_traits, and arc_stage_parts tables to use unique constraints instead of primary keys for composite indexes.
2026-05-07 13:30:30 +02:00

159 lines
5.0 KiB
Go

package auth
import (
"context"
"encoding/base64"
"log/slog"
"net/http"
"strings"
"time"
"git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/observability"
"git.warky.dev/wdevs/amcs/internal/requestip"
)
// contextKey is an unexported type for context keys so that values stored by
// this package cannot collide with context keys defined in other packages.
type contextKey string

// keyIDContextKey is the context key under which Middleware stores the
// authenticated key ID; KeyIDFromContext reads it back.
const keyIDContextKey = contextKey("auth.key_id")
// wwwAuthenticate returns the value for a WWW-Authenticate header.
// It advertises the Bearer scheme with a resource_metadata parameter that
// points at the server's OAuth metadata document.
//
// The base URL is publicURL when it is non-empty; trailing slashes are dropped
// so the well-known path is not doubled. Otherwise the base is rebuilt from the
// request. The host comes from r.Host. The scheme comes from the first entry of
// X-Forwarded-Proto when that entry is non-blank. If there is no usable entry,
// the scheme is "https" for TLS connections and "http" otherwise.
//
// NOTE(review): RFC 9728 defines resource_metadata as the URL of the
// *protected resource* metadata, whose well-known path is
// /.well-known/oauth-protected-resource. This function advertises the RFC 8414
// authorization-server metadata path instead. Confirm which document the server
// actually serves before changing it.
func wwwAuthenticate(r *http.Request, publicURL string) string {
	base := strings.TrimRight(publicURL, "/")
	if base == "" {
		scheme := "http"
		if r.TLS != nil {
			scheme = "https"
		}
		// Behind a reverse proxy the connection's TLS state does not reflect
		// what the client used, so the proxy header takes precedence. Proxies
		// may append to it ("https, http"); the first entry is the one set by
		// the outermost proxy. The header is client-controllable unless a
		// trusted proxy overwrites it, but here it only shapes the advertised URL.
		if proto, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ","); strings.TrimSpace(proto) != "" {
			scheme = strings.ToLower(strings.TrimSpace(proto))
		}
		base = scheme + "://" + r.Host
	}
	return `Bearer resource_metadata="` + base + `/.well-known/oauth-authorization-server"`
}
// Middleware returns HTTP middleware that authenticates each request. On
// success it stores the resolved key ID in the request context (readable via
// KeyIDFromContext), records the access, and calls the next handler.
//
// Credentials are checked in the order below. The first credential that is
// present decides the outcome: a present-but-invalid credential is rejected
// without trying later sources.
//  1. The custom API-key header cfg.HeaderName (default "x-brain-key"),
//     validated against keyring. It is skipped when keyring is nil.
//  2. An "Authorization: Bearer" token, validated against tokenStore and then
//     against keyring. Either may be nil.
//  3. "Authorization: Basic" client credentials, validated against
//     oauthRegistry. They are rejected when oauthRegistry is nil.
//  4. The query parameter cfg.QueryParam, validated against keyring. It is only
//     checked when cfg.AllowQueryParam is set and keyring is non-nil.
//
// tracker may be nil, in which case accesses are not recorded. A nil log falls
// back to slog.Default(). Every 401 response carries a WWW-Authenticate
// challenge, as RFC 7235 §3.1 requires.
func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, tracker *AccessTracker, log *slog.Logger) func(http.Handler) http.Handler {
	headerName := cfg.HeaderName
	if headerName == "" {
		headerName = "x-brain-key"
	}
	if log == nil {
		log = slog.Default()
	}

	// recordAccess reports a successful authentication to the tracker, if any.
	recordAccess := func(r *http.Request, keyID string) {
		if tracker == nil {
			return
		}
		tracker.Record(
			keyID,
			r.URL.Path,
			requestip.FromRequest(r),
			r.UserAgent(),
			observability.MCPToolFromContext(r.Context()),
			time.Now(),
		)
	}

	return func(next http.Handler) http.Handler {
		// serveAs records the access and forwards the request with keyID in its context.
		serveAs := func(w http.ResponseWriter, r *http.Request, keyID string) {
			recordAccess(r, keyID)
			next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
		}

		// reject writes a 401 carrying a Bearer challenge. challengeParams
		// (e.g. `, error="invalid_token"`) is appended to the challenge verbatim.
		// Per RFC 6750 §3, the error parameter is only added when a bearer token
		// was actually presented.
		reject := func(w http.ResponseWriter, r *http.Request, challengeParams, msg string) {
			w.Header().Set("WWW-Authenticate", wwwAuthenticate(r, "")+challengeParams)
			http.Error(w, msg, http.StatusUnauthorized)
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			remoteAddr := requestip.FromRequest(r)

			// 1. Custom header → keyring only.
			if keyring != nil {
				if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
					keyID, ok := keyring.Lookup(token)
					if !ok {
						log.Warn("authentication failed", slog.String("remote_addr", remoteAddr))
						reject(w, r, "", "invalid API key")
						return
					}
					serveAs(w, r, keyID)
					return
				}
			}

			// 2. Bearer token → tokenStore (OAuth), then keyring (API key).
			if bearer := extractBearer(r); bearer != "" {
				if tokenStore != nil {
					if keyID, ok := tokenStore.Lookup(bearer); ok {
						serveAs(w, r, keyID)
						return
					}
				}
				if keyring != nil {
					if keyID, ok := keyring.Lookup(bearer); ok {
						serveAs(w, r, keyID)
						return
					}
				}
				log.Warn("bearer token rejected", slog.String("remote_addr", remoteAddr))
				reject(w, r, `, error="invalid_token"`, "invalid token or API key")
				return
			}

			// 3. HTTP Basic → oauthRegistry (direct client credentials).
			if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" {
				if oauthRegistry == nil {
					reject(w, r, "", "authentication is not configured")
					return
				}
				keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
				if !ok {
					log.Warn("oauth client authentication failed", slog.String("remote_addr", remoteAddr))
					reject(w, r, "", "invalid OAuth client credentials")
					return
				}
				serveAs(w, r, keyID)
				return
			}

			// 4. Query param → keyring (only when explicitly enabled).
			if keyring != nil && cfg.AllowQueryParam {
				if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
					keyID, ok := keyring.Lookup(token)
					if !ok {
						log.Warn("authentication failed", slog.String("remote_addr", remoteAddr))
						reject(w, r, "", "invalid API key")
						return
					}
					serveAs(w, r, keyID)
					return
				}
			}

			// No credential of any supported kind was presented.
			reject(w, r, "", "authentication required")
		})
	}
}
// extractBearer returns the credentials of an "Authorization: Bearer <token>"
// header. The scheme is matched case-insensitively, and surrounding whitespace
// is trimmed from both the header and the token. It returns "" when the header
// is missing, uses another scheme, or has no token part.
func extractBearer(r *http.Request) string {
	header := strings.TrimSpace(r.Header.Get("Authorization"))

	sep := strings.IndexByte(header, ' ')
	if sep < 0 {
		return ""
	}
	if !strings.EqualFold(header[:sep], "Bearer") {
		return ""
	}

	return strings.TrimSpace(header[sep+1:])
}
// extractOAuthClientCredentials parses an "Authorization: Basic <base64>"
// header into a client ID and client secret. The scheme is matched
// case-insensitively. The decoded payload is split at its first ':', and both
// halves are whitespace-trimmed. It returns two empty strings when the header
// is absent, uses another scheme, is not valid standard base64, or contains no
// ':' separator.
//
// NOTE(review): RFC 6749 §2.3.1 has clients form-urlencode the ID and secret
// before Basic-encoding them. That encoding is not undone here, so secrets
// containing reserved characters may not match. Confirm what clients send.
func extractOAuthClientCredentials(r *http.Request) (string, string) {
	header := strings.TrimSpace(r.Header.Get("Authorization"))

	scheme, encoded, hasSep := strings.Cut(header, " ")
	if !hasSep || !strings.EqualFold(scheme, "Basic") {
		return "", ""
	}

	raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(encoded))
	if err != nil {
		return "", ""
	}

	id, secret, hasColon := strings.Cut(string(raw), ":")
	if !hasColon {
		return "", ""
	}
	return strings.TrimSpace(id), strings.TrimSpace(secret)
}
// KeyIDFromContext returns the key ID that Middleware attached to ctx after a
// successful authentication. The boolean is false, and the ID empty, when ctx
// carries no key ID.
func KeyIDFromContext(ctx context.Context) (string, bool) {
	if keyID, found := ctx.Value(keyIDContextKey).(string); found {
		return keyID, true
	}
	return "", false
}