* Implement tests for error functions like errRequiredField, errInvalidField, and errEntityNotFound. * Ensure proper metadata is returned for various error scenarios. * Validate error handling in CRM, Files, and other tools. * Introduce tests for parsing stored file IDs and UUIDs. * Enhance coverage for helper functions related to project resolution and session management.
120 lines
3.8 KiB
Go
120 lines
3.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/config"
|
|
)
|
|
|
|
// contextKey is an unexported string type used for this package's context
// keys. Because the type is private, keys from other packages can never
// compare equal to it, even if they share the same underlying text.
type contextKey string

// keyIDContextKey is the context key under which Middleware stores the
// authenticated key ID that KeyIDFromContext reads back.
const keyIDContextKey = contextKey("auth.key_id")
|
|
|
|
func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, log *slog.Logger) func(http.Handler) http.Handler {
|
|
headerName := cfg.HeaderName
|
|
if headerName == "" {
|
|
headerName = "x-brain-key"
|
|
}
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// 1. Custom header → keyring only.
|
|
if keyring != nil {
|
|
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
|
|
keyID, ok := keyring.Lookup(token)
|
|
if !ok {
|
|
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
|
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
|
return
|
|
}
|
|
}
|
|
|
|
// 2. Bearer token → tokenStore (OAuth), then keyring (API key).
|
|
if bearer := extractBearer(r); bearer != "" {
|
|
if tokenStore != nil {
|
|
if keyID, ok := tokenStore.Lookup(bearer); ok {
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
|
return
|
|
}
|
|
}
|
|
if keyring != nil {
|
|
if keyID, ok := keyring.Lookup(bearer); ok {
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
|
return
|
|
}
|
|
}
|
|
log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr))
|
|
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// 3. HTTP Basic → oauthRegistry (direct client credentials).
|
|
if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" {
|
|
if oauthRegistry == nil {
|
|
http.Error(w, "authentication is not configured", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
|
if !ok {
|
|
log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
|
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
|
return
|
|
}
|
|
|
|
// 4. Query param → keyring.
|
|
if keyring != nil && cfg.AllowQueryParam {
|
|
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
|
|
keyID, ok := keyring.Lookup(token)
|
|
if !ok {
|
|
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
|
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
|
return
|
|
}
|
|
}
|
|
|
|
http.Error(w, "authentication required", http.StatusUnauthorized)
|
|
})
|
|
}
|
|
}
|
|
|
|
// extractBearer returns the token from an "Authorization: Bearer <token>"
// header. The scheme name is matched case-insensitively, and surrounding
// whitespace is trimmed from both the header and the token. It returns ""
// when the header is missing, has no space-separated credential, or uses a
// different scheme.
func extractBearer(r *http.Request) string {
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	scheme, token, found := strings.Cut(header, " ")
	if found && strings.EqualFold(scheme, "Bearer") {
		return strings.TrimSpace(token)
	}
	return ""
}
|
|
|
|
// extractOAuthClientCredentials decodes an "Authorization: Basic <base64>"
// header into a client ID and secret.
//
// The scheme name is matched case-insensitively. The payload is decoded with
// standard (padded) base64 and split at the first ':'. Both halves are
// whitespace-trimmed. It returns ("", "") when the header is missing, uses
// another scheme, is not valid base64, or the decoded payload has no ':'.
func extractOAuthClientCredentials(r *http.Request) (string, string) {
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	scheme, encoded, hasSpace := strings.Cut(header, " ")
	if !hasSpace || !strings.EqualFold(scheme, "Basic") {
		return "", ""
	}

	raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(encoded))
	if err != nil {
		return "", ""
	}

	id, secret, hasColon := strings.Cut(string(raw), ":")
	if !hasColon {
		return "", ""
	}
	return strings.TrimSpace(id), strings.TrimSpace(secret)
}
|
|
|
|
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
|
value, ok := ctx.Value(keyIDContextKey).(string)
|
|
return value, ok
|
|
}
|