feat(auth): implement OAuth 2.0 authorization code flow and dynamic client registration
- Add OAuth 2.0 support with authorization code flow and dynamic client registration. - Introduce new handlers for OAuth metadata, client registration, authorization, and token issuance. - Enhance authentication middleware to support OAuth client credentials. - Create in-memory stores for authorization codes and tokens. - Update configuration to include OAuth client details. - Ensure validation checks for OAuth clients in the configuration.
This commit is contained in:
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -13,32 +14,77 @@ type contextKey string
|
||||
|
||||
const keyIDContextKey contextKey = "auth.key_id"
|
||||
|
||||
func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler {
|
||||
func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, log *slog.Logger) func(http.Handler) http.Handler {
|
||||
headerName := cfg.HeaderName
|
||||
if headerName == "" {
|
||||
headerName = "x-brain-key"
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractToken(r, headerName)
|
||||
if token == "" && cfg.AllowQueryParam {
|
||||
token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam))
|
||||
// 1. Custom header → keyring only.
|
||||
if keyring != nil {
|
||||
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
|
||||
keyID, ok := keyring.Lookup(token)
|
||||
if !ok {
|
||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
http.Error(w, "missing API key", http.StatusUnauthorized)
|
||||
// 2. Bearer token → tokenStore (OAuth), then keyring (API key).
|
||||
if bearer := extractBearer(r); bearer != "" {
|
||||
if tokenStore != nil {
|
||||
if keyID, ok := tokenStore.Lookup(bearer); ok {
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
if keyring != nil {
|
||||
if keyID, ok := keyring.Lookup(bearer); ok {
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr))
|
||||
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
keyID, ok := keyring.Lookup(token)
|
||||
if !ok {
|
||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||
// 3. HTTP Basic → oauthRegistry (direct client credentials).
|
||||
if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" {
|
||||
if oauthRegistry == nil {
|
||||
http.Error(w, "authentication is not configured", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||
if !ok {
|
||||
log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
// 4. Query param → keyring.
|
||||
if keyring != nil && cfg.AllowQueryParam {
|
||||
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
|
||||
keyID, ok := keyring.Lookup(token)
|
||||
if !ok {
|
||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(w, "authentication required", http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -58,6 +104,30 @@ func extractToken(r *http.Request, headerName string) string {
|
||||
return strings.TrimSpace(credentials)
|
||||
}
|
||||
|
||||
func extractBearer(r *http.Request) string {
|
||||
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
scheme, credentials, ok := strings.Cut(authHeader, " ")
|
||||
if !ok || !strings.EqualFold(scheme, "Bearer") {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(credentials)
|
||||
}
|
||||
|
||||
func extractOAuthClientCredentials(r *http.Request) (string, string) {
|
||||
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
scheme, credentials, ok := strings.Cut(authHeader, " ")
|
||||
if ok && strings.EqualFold(scheme, "Basic") {
|
||||
decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(credentials))
|
||||
if err == nil {
|
||||
clientID, clientSecret, found := strings.Cut(string(decoded), ":")
|
||||
if found {
|
||||
return strings.TrimSpace(clientID), strings.TrimSpace(clientSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
||||
value, ok := ctx.Value(keyIDContextKey).(string)
|
||||
return value, ok
|
||||
|
||||
Reference in New Issue
Block a user