feat(auth): implement OAuth 2.0 authorization code flow and dynamic client registration

- Add OAuth 2.0 support with authorization code flow and dynamic client registration.
- Introduce new handlers for OAuth metadata, client registration, authorization, and token issuance.
- Enhance authentication middleware to support OAuth client credentials.
- Create in-memory stores for authorization codes and tokens.
- Update configuration to include OAuth client details.
- Ensure validation checks for OAuth clients in the configuration.
This commit is contained in:
2026-03-26 21:17:55 +02:00
parent ed05d390b7
commit 56c84df342
19 changed files with 970 additions and 40 deletions

View File

@@ -50,10 +50,24 @@ func Run(ctx context.Context, configPath string) error {
return err
}
keyring, err := auth.NewKeyring(cfg.Auth.Keys)
if err != nil {
return err
var keyring *auth.Keyring
var oauthRegistry *auth.OAuthRegistry
var tokenStore *auth.TokenStore
if len(cfg.Auth.Keys) > 0 {
keyring, err = auth.NewKeyring(cfg.Auth.Keys)
if err != nil {
return err
}
}
if len(cfg.Auth.OAuth.Clients) > 0 {
oauthRegistry, err = auth.NewOAuthRegistry(cfg.Auth.OAuth.Clients)
if err != nil {
return err
}
tokenStore = auth.NewTokenStore(0)
}
authCodes := auth.NewAuthCodeStore()
dynClients := auth.NewDynamicClientStore()
activeProjects := session.NewActiveProjects()
logger.Info("database connection verified",
@@ -62,7 +76,7 @@ func Run(ctx context.Context, configPath string) error {
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: routes(logger, cfg, db, provider, keyring, activeProjects),
Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects),
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
@@ -92,7 +106,7 @@ func Run(ctx context.Context, configPath string) error {
}
}
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, activeProjects *session.ActiveProjects) http.Handler {
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) http.Handler {
mux := http.NewServeMux()
toolSet := mcpserver.ToolSet{
@@ -112,7 +126,13 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
}
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, logger)(mcpHandler))
mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, logger)(mcpHandler))
if oauthRegistry != nil && tokenStore != nil {
mux.HandleFunc("/.well-known/oauth-authorization-server", oauthMetadataHandler())
mux.HandleFunc("/oauth/register", oauthRegisterHandler(dynClients, logger))
mux.HandleFunc("/oauth/authorize", oauthAuthorizeHandler(dynClients, authCodes, logger))
mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger))
}
mux.HandleFunc("/favicon.ico", serveFavicon)
mux.HandleFunc("/llm", serveLLMInstructions)

390
internal/app/oauth.go Normal file
View File

@@ -0,0 +1,390 @@
package app
import (
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"fmt"
"html"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
"git.warky.dev/wdevs/amcs/internal/auth"
)
// --- JSON types ---
type tokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
type tokenErrorResponse struct {
Error string `json:"error"`
}
type oauthServerMetadata struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
RegistrationEndpoint string `json:"registration_endpoint"`
ScopesSupported []string `json:"scopes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
}
type registerRequest struct {
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
}
type registerResponse struct {
ClientID string `json:"client_id"`
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
}
// --- Handlers ---
// oauthMetadataHandler serves GET /.well-known/oauth-authorization-server
// per RFC 8414 for OAuth 2.0 server metadata discovery.
func oauthMetadataHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
base := serverBaseURL(r)
meta := oauthServerMetadata{
Issuer: base,
AuthorizationEndpoint: base + "/oauth/authorize",
TokenEndpoint: base + "/oauth/token",
RegistrationEndpoint: base + "/oauth/register",
ScopesSupported: []string{"mcp"},
ResponseTypesSupported: []string{"code"},
GrantTypesSupported: []string{"authorization_code", "client_credentials"},
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "none"},
CodeChallengeMethodsSupported: []string{"S256"},
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(meta)
}
}
// oauthRegisterHandler serves POST /oauth/register per RFC 7591
// (OAuth 2.0 Dynamic Client Registration).
func oauthRegisterHandler(dynClients *auth.DynamicClientStore, log *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost)
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req registerRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
if len(req.RedirectURIs) == 0 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_client_metadata","error_description":"redirect_uris is required"}`))
return
}
client, err := dynClients.Register(req.ClientName, req.RedirectURIs)
if err != nil {
log.Error("oauth register: failed", slog.String("error", err.Error()))
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
log.Info("oauth register: new client",
slog.String("client_id", client.ClientID),
slog.String("client_name", client.ClientName),
)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(registerResponse{
ClientID: client.ClientID,
ClientName: client.ClientName,
RedirectURIs: client.RedirectURIs,
GrantTypes: []string{"authorization_code"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "none",
})
}
}
// oauthAuthorizeHandler serves GET and POST /oauth/authorize.
// GET shows an approval page; POST processes the user's approve/deny action.
func oauthAuthorizeHandler(dynClients *auth.DynamicClientStore, authCodes *auth.AuthCodeStore, log *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
handleAuthorizeGET(w, r, dynClients)
case http.MethodPost:
handleAuthorizePOST(w, r, dynClients, authCodes, log)
default:
w.Header().Set("Allow", "GET, POST")
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
}
func handleAuthorizeGET(w http.ResponseWriter, r *http.Request, dynClients *auth.DynamicClientStore) {
q := r.URL.Query()
clientID := q.Get("client_id")
redirectURI := q.Get("redirect_uri")
responseType := q.Get("response_type")
state := q.Get("state")
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
scope := q.Get("scope")
// Validate client and redirect_uri before any redirect (prevents open redirect).
client, ok := dynClients.Lookup(clientID)
if !ok {
http.Error(w, "unknown client_id", http.StatusBadRequest)
return
}
if !client.HasRedirectURI(redirectURI) {
http.Error(w, "redirect_uri not registered for this client", http.StatusBadRequest)
return
}
// Errors from here can safely redirect back to the client.
if responseType != "code" {
oauthRedirectError(w, r, redirectURI, "unsupported_response_type", state)
return
}
if codeChallenge == "" || codeChallengeMethod != "S256" {
oauthRedirectError(w, r, redirectURI, "invalid_request", state)
return
}
serveAuthorizePage(w, client.ClientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope)
}
func handleAuthorizePOST(w http.ResponseWriter, r *http.Request, dynClients *auth.DynamicClientStore, authCodes *auth.AuthCodeStore, log *slog.Logger) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
clientID := r.FormValue("client_id")
redirectURI := r.FormValue("redirect_uri")
state := r.FormValue("state")
codeChallenge := r.FormValue("code_challenge")
codeChallengeMethod := r.FormValue("code_challenge_method")
scope := r.FormValue("scope")
action := r.FormValue("action")
client, ok := dynClients.Lookup(clientID)
if !ok {
http.Error(w, "unknown client_id", http.StatusBadRequest)
return
}
if !client.HasRedirectURI(redirectURI) {
http.Error(w, "redirect_uri not registered for this client", http.StatusBadRequest)
return
}
if action == "deny" {
oauthRedirectError(w, r, redirectURI, "access_denied", state)
return
}
code, err := authCodes.Issue(auth.AuthCode{
ClientID: clientID,
RedirectURI: redirectURI,
Scope: scope,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
KeyID: clientID,
})
if err != nil {
log.Error("oauth authorize: failed to issue code", slog.String("error", err.Error()))
oauthRedirectError(w, r, redirectURI, "server_error", state)
return
}
target := redirectURI + "?code=" + url.QueryEscape(code)
if state != "" {
target += "&state=" + url.QueryEscape(state)
}
http.Redirect(w, r, target, http.StatusFound)
}
// oauthTokenHandler serves POST /oauth/token.
// Supports grant_type=client_credentials and grant_type=authorization_code.
func oauthTokenHandler(oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, log *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost)
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
writeTokenError(w, "invalid_request", http.StatusBadRequest)
return
}
switch r.FormValue("grant_type") {
case "client_credentials":
handleClientCredentials(w, r, oauthRegistry, tokenStore, log)
case "authorization_code":
handleAuthorizationCode(w, r, authCodes, tokenStore, log)
default:
writeTokenError(w, "unsupported_grant_type", http.StatusBadRequest)
}
}
}
func handleClientCredentials(w http.ResponseWriter, r *http.Request, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, log *slog.Logger) {
clientID, clientSecret := extractOAuthBasicOrBody(r)
if clientID == "" || clientSecret == "" {
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
return
}
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
if !ok {
log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", r.RemoteAddr))
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
return
}
issueToken(w, keyID, tokenStore, log)
}
func handleAuthorizationCode(w http.ResponseWriter, r *http.Request, authCodes *auth.AuthCodeStore, tokenStore *auth.TokenStore, log *slog.Logger) {
code := r.FormValue("code")
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
codeVerifier := r.FormValue("code_verifier")
if code == "" || redirectURI == "" || codeVerifier == "" {
writeTokenError(w, "invalid_request", http.StatusBadRequest)
return
}
entry, ok := authCodes.Consume(code)
if !ok {
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
return
}
if entry.ClientID != clientID || entry.RedirectURI != redirectURI {
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
return
}
if !verifyPKCE(codeVerifier, entry.CodeChallenge, entry.CodeChallengeMethod) {
log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", r.RemoteAddr))
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
return
}
issueToken(w, entry.KeyID, tokenStore, log)
}
func issueToken(w http.ResponseWriter, keyID string, tokenStore *auth.TokenStore, log *slog.Logger) {
token, ttl, err := tokenStore.Issue(keyID)
if err != nil {
log.Error("oauth token: failed to issue token", slog.String("error", err.Error()))
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
_ = json.NewEncoder(w).Encode(tokenResponse{
AccessToken: token,
TokenType: "bearer",
ExpiresIn: int(ttl / time.Second),
})
}
// --- Helpers ---
func serveAuthorizePage(w http.ResponseWriter, clientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope string) {
e := html.EscapeString
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head>
<meta charset=utf-8>
<title>Authorize — AMCS</title>
<style>
body{font-family:system-ui,sans-serif;max-width:480px;margin:80px auto;padding:0 1rem}
button{padding:.5rem 1.2rem;margin-right:.5rem;cursor:pointer;font-size:1rem}
</style>
</head>
<body>
<h2>Authorize Access</h2>
<p><strong>%s</strong> is requesting access to this AMCS server.</p>
<form method=POST action=/oauth/authorize>
<input type=hidden name=client_id value="%s">
<input type=hidden name=redirect_uri value="%s">
<input type=hidden name=state value="%s">
<input type=hidden name=code_challenge value="%s">
<input type=hidden name=code_challenge_method value="%s">
<input type=hidden name=scope value="%s">
<button type=submit name=action value=approve>Approve</button>
<button type=submit name=action value=deny>Deny</button>
</form>
</body>
</html>`,
e(clientName), e(clientID), e(redirectURI), e(state),
e(codeChallenge), e(codeChallengeMethod), e(scope))
}
func oauthRedirectError(w http.ResponseWriter, r *http.Request, redirectURI, errCode, state string) {
target := redirectURI + "?error=" + url.QueryEscape(errCode)
if state != "" {
target += "&state=" + url.QueryEscape(state)
}
http.Redirect(w, r, target, http.StatusFound)
}
func verifyPKCE(verifier, challenge, method string) bool {
if method != "S256" {
return false
}
h := sha256.Sum256([]byte(verifier))
got := base64.RawURLEncoding.EncodeToString(h[:])
return subtle.ConstantTimeCompare([]byte(got), []byte(challenge)) == 1
}
func serverBaseURL(r *http.Request) string {
scheme := "https"
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = strings.ToLower(proto)
} else if r.TLS == nil {
scheme = "http"
}
return scheme + "://" + r.Host
}
func extractOAuthBasicOrBody(r *http.Request) (string, string) {
if id, secret, ok := r.BasicAuth(); ok {
return id, secret
}
return strings.TrimSpace(r.FormValue("client_id")), strings.TrimSpace(r.FormValue("client_secret"))
}
func writeTokenError(w http.ResponseWriter, errCode string, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(tokenErrorResponse{Error: errCode})
}

View File

@@ -0,0 +1,76 @@
package auth
import (
"crypto/rand"
"encoding/hex"
"sync"
"time"
)
const authCodeTTL = 10 * time.Minute
// AuthCode holds a pending authorization code and its associated PKCE data.
type AuthCode struct {
ClientID string
RedirectURI string
Scope string
CodeChallenge string
CodeChallengeMethod string
KeyID string
ExpiresAt time.Time
}
// AuthCodeStore issues single-use authorization codes for the OAuth 2.0
// Authorization Code flow.
type AuthCodeStore struct {
mu sync.Mutex
codes map[string]AuthCode
}
func NewAuthCodeStore() *AuthCodeStore {
s := &AuthCodeStore{codes: make(map[string]AuthCode)}
go s.sweepLoop()
return s
}
// Issue stores the entry and returns the raw authorization code.
func (s *AuthCodeStore) Issue(entry AuthCode) (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
raw := hex.EncodeToString(b)
entry.ExpiresAt = time.Now().Add(authCodeTTL)
s.mu.Lock()
s.codes[raw] = entry
s.mu.Unlock()
return raw, nil
}
// Consume validates and removes the code, returning the associated entry.
func (s *AuthCodeStore) Consume(code string) (AuthCode, bool) {
s.mu.Lock()
defer s.mu.Unlock()
entry, ok := s.codes[code]
if !ok || time.Now().After(entry.ExpiresAt) {
delete(s.codes, code)
return AuthCode{}, false
}
delete(s.codes, code)
return entry, true
}
func (s *AuthCodeStore) sweepLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
s.mu.Lock()
for code, entry := range s.codes {
if now.After(entry.ExpiresAt) {
delete(s.codes, code)
}
}
s.mu.Unlock()
}
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"crypto/rand"
"encoding/hex"
"sync"
"time"
)
// DynamicClient holds a dynamically registered OAuth client (RFC 7591).
type DynamicClient struct {
ClientID string
ClientName string
RedirectURIs []string
CreatedAt time.Time
}
// HasRedirectURI reports whether uri is registered for this client.
func (c *DynamicClient) HasRedirectURI(uri string) bool {
for _, u := range c.RedirectURIs {
if u == uri {
return true
}
}
return false
}
// DynamicClientStore holds dynamically registered OAuth clients in memory.
type DynamicClientStore struct {
mu sync.RWMutex
clients map[string]DynamicClient
}
func NewDynamicClientStore() *DynamicClientStore {
return &DynamicClientStore{clients: make(map[string]DynamicClient)}
}
// Register creates a new client and returns it.
func (s *DynamicClientStore) Register(name string, redirectURIs []string) (DynamicClient, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return DynamicClient{}, err
}
client := DynamicClient{
ClientID: hex.EncodeToString(b),
ClientName: name,
RedirectURIs: append([]string(nil), redirectURIs...),
CreatedAt: time.Now(),
}
s.mu.Lock()
s.clients[client.ClientID] = client
s.mu.Unlock()
return client, nil
}
// Lookup returns the client for the given client_id.
func (s *DynamicClientStore) Lookup(clientID string) (DynamicClient, bool) {
s.mu.RLock()
client, ok := s.clients[clientID]
s.mu.RUnlock()
return client, ok
}

View File

@@ -39,7 +39,7 @@ func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) {
t.Fatalf("NewKeyring() error = %v", err)
}
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keyID, ok := KeyIDFromContext(r.Context())
if !ok || keyID != "client-a" {
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
@@ -63,7 +63,7 @@ func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) {
t.Fatalf("NewKeyring() error = %v", err)
}
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keyID, ok := KeyIDFromContext(r.Context())
if !ok || keyID != "client-a" {
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
@@ -90,7 +90,7 @@ func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) {
t.Fatalf("NewKeyring() error = %v", err)
}
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keyID, ok := KeyIDFromContext(r.Context())
if !ok || keyID != "client-a" {
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
@@ -119,7 +119,7 @@ func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) {
HeaderName: "x-brain-key",
QueryParam: "key",
AllowQueryParam: true,
}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
@@ -138,7 +138,7 @@ func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) {
t.Fatalf("NewKeyring() error = %v", err)
}
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))

View File

@@ -2,6 +2,7 @@ package auth
import (
"context"
"encoding/base64"
"log/slog"
"net/http"
"strings"
@@ -13,32 +14,77 @@ type contextKey string
const keyIDContextKey contextKey = "auth.key_id"
func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler {
func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, log *slog.Logger) func(http.Handler) http.Handler {
headerName := cfg.HeaderName
if headerName == "" {
headerName = "x-brain-key"
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := extractToken(r, headerName)
if token == "" && cfg.AllowQueryParam {
token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam))
// 1. Custom header → keyring only.
if keyring != nil {
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
keyID, ok := keyring.Lookup(token)
if !ok {
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
http.Error(w, "invalid API key", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
return
}
}
if token == "" {
http.Error(w, "missing API key", http.StatusUnauthorized)
// 2. Bearer token → tokenStore (OAuth), then keyring (API key).
if bearer := extractBearer(r); bearer != "" {
if tokenStore != nil {
if keyID, ok := tokenStore.Lookup(bearer); ok {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
return
}
}
if keyring != nil {
if keyID, ok := keyring.Lookup(bearer); ok {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
return
}
}
log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr))
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
return
}
keyID, ok := keyring.Lookup(token)
if !ok {
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
http.Error(w, "invalid API key", http.StatusUnauthorized)
// 3. HTTP Basic → oauthRegistry (direct client credentials).
if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" {
if oauthRegistry == nil {
http.Error(w, "authentication is not configured", http.StatusUnauthorized)
return
}
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
if !ok {
log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr))
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
return
}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
// 4. Query param → keyring.
if keyring != nil && cfg.AllowQueryParam {
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
keyID, ok := keyring.Lookup(token)
if !ok {
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
http.Error(w, "invalid API key", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
return
}
}
http.Error(w, "authentication required", http.StatusUnauthorized)
})
}
}
@@ -58,6 +104,30 @@ func extractToken(r *http.Request, headerName string) string {
return strings.TrimSpace(credentials)
}
func extractBearer(r *http.Request) string {
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
scheme, credentials, ok := strings.Cut(authHeader, " ")
if !ok || !strings.EqualFold(scheme, "Bearer") {
return ""
}
return strings.TrimSpace(credentials)
}
func extractOAuthClientCredentials(r *http.Request) (string, string) {
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
scheme, credentials, ok := strings.Cut(authHeader, " ")
if ok && strings.EqualFold(scheme, "Basic") {
decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(credentials))
if err == nil {
clientID, clientSecret, found := strings.Cut(string(decoded), ":")
if found {
return strings.TrimSpace(clientID), strings.TrimSpace(clientSecret)
}
}
}
return "", ""
}
func KeyIDFromContext(ctx context.Context) (string, bool) {
value, ok := ctx.Value(keyIDContextKey).(string)
return value, ok

View File

@@ -0,0 +1,33 @@
package auth
import (
"crypto/subtle"
"fmt"
"git.warky.dev/wdevs/amcs/internal/config"
)
type OAuthRegistry struct {
clients []config.OAuthClient
}
func NewOAuthRegistry(clients []config.OAuthClient) (*OAuthRegistry, error) {
if len(clients) == 0 {
return nil, fmt.Errorf("oauth registry requires at least one client")
}
return &OAuthRegistry{clients: append([]config.OAuthClient(nil), clients...)}, nil
}
func (o *OAuthRegistry) Lookup(clientID string, clientSecret string) (string, bool) {
for _, client := range o.clients {
if subtle.ConstantTimeCompare([]byte(client.ClientID), []byte(clientID)) == 1 &&
subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 1 {
if client.ID != "" {
return client.ID, true
}
return client.ClientID, true
}
}
return "", false
}

View File

@@ -0,0 +1,92 @@
package auth
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
"git.warky.dev/wdevs/amcs/internal/config"
)
func TestNewOAuthRegistryAndLookup(t *testing.T) {
_, err := NewOAuthRegistry(nil)
if err == nil {
t.Fatal("NewOAuthRegistry(nil) error = nil, want error")
}
registry, err := NewOAuthRegistry([]config.OAuthClient{{
ID: "oauth-client",
ClientID: "client-id",
ClientSecret: "client-secret",
}})
if err != nil {
t.Fatalf("NewOAuthRegistry() error = %v", err)
}
if got, ok := registry.Lookup("client-id", "client-secret"); !ok || got != "oauth-client" {
t.Fatalf("Lookup(client-id, client-secret) = (%q, %v), want (oauth-client, true)", got, ok)
}
if _, ok := registry.Lookup("client-id", "wrong"); ok {
t.Fatal("Lookup(client-id, wrong) = true, want false")
}
}
func TestMiddlewareAllowsOAuthBasicAuthAndSetsContext(t *testing.T) {
oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{
ID: "oauth-client",
ClientID: "client-id",
ClientSecret: "client-secret",
}})
if err != nil {
t.Fatalf("NewOAuthRegistry() error = %v", err)
}
handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keyID, ok := KeyIDFromContext(r.Context())
if !ok || keyID != "oauth-client" {
t.Fatalf("KeyIDFromContext() = (%q, %v), want (oauth-client, true)", keyID, ok)
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:client-secret")))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestMiddlewareRejectsOAuthMissingOrInvalidCredentials(t *testing.T) {
oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{
ID: "oauth-client",
ClientID: "client-id",
ClientSecret: "client-secret",
}})
if err != nil {
t.Fatalf("NewOAuthRegistry() error = %v", err)
}
handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("missing credentials status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
req = httptest.NewRequest(http.MethodGet, "/mcp", nil)
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:wrong")))
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("invalid credentials status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}

View File

@@ -0,0 +1,74 @@
package auth
import (
"crypto/rand"
"encoding/hex"
"sync"
"time"
)
const defaultTokenTTL = time.Hour
type tokenEntry struct {
keyID string
expiresAt time.Time
}
// TokenStore issues and validates short-lived opaque access tokens for OAuth
// client credentials flow.
type TokenStore struct {
mu sync.RWMutex
tokens map[string]tokenEntry
ttl time.Duration
}
func NewTokenStore(ttl time.Duration) *TokenStore {
if ttl <= 0 {
ttl = defaultTokenTTL
}
s := &TokenStore{
tokens: make(map[string]tokenEntry),
ttl: ttl,
}
go s.sweepLoop()
return s
}
// Issue generates a new token for the given keyID and returns the token and its TTL.
func (s *TokenStore) Issue(keyID string) (string, time.Duration, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", 0, err
}
token := hex.EncodeToString(b)
s.mu.Lock()
s.tokens[token] = tokenEntry{keyID: keyID, expiresAt: time.Now().Add(s.ttl)}
s.mu.Unlock()
return token, s.ttl, nil
}
// Lookup validates a token and returns the associated keyID.
func (s *TokenStore) Lookup(token string) (string, bool) {
s.mu.RLock()
entry, ok := s.tokens[token]
s.mu.RUnlock()
if !ok || time.Now().After(entry.expiresAt) {
return "", false
}
return entry.keyID, true
}
func (s *TokenStore) sweepLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
s.mu.Lock()
for token, entry := range s.tokens {
if now.After(entry.expiresAt) {
delete(s.tokens, token)
}
}
s.mu.Unlock()
}
}

View File

@@ -36,11 +36,11 @@ type MCPConfig struct {
}
type AuthConfig struct {
Mode string `yaml:"mode"`
HeaderName string `yaml:"header_name"`
QueryParam string `yaml:"query_param"`
AllowQueryParam bool `yaml:"allow_query_param"`
Keys []APIKey `yaml:"keys"`
HeaderName string `yaml:"header_name"`
QueryParam string `yaml:"query_param"`
AllowQueryParam bool `yaml:"allow_query_param"`
Keys []APIKey `yaml:"keys"`
OAuth OAuthConfig `yaml:"oauth"`
}
type APIKey struct {
@@ -49,6 +49,17 @@ type APIKey struct {
Description string `yaml:"description"`
}
type OAuthConfig struct {
Clients []OAuthClient `yaml:"clients"`
}
type OAuthClient struct {
ID string `yaml:"id"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
Description string `yaml:"description"`
}
type DatabaseConfig struct {
URL string `yaml:"url"`
MaxConns int32 `yaml:"max_conns"`

View File

@@ -32,8 +32,10 @@ func Load(explicitPath string) (*Config, string, error) {
}
func ResolvePath(explicitPath string) string {
if strings.TrimSpace(explicitPath) != "" {
return explicitPath
if path := strings.TrimSpace(explicitPath); path != "" {
if path != ".yaml" && path != ".yml" {
return path
}
}
if envPath := strings.TrimSpace(os.Getenv("AMCS_CONFIG")); envPath != "" {
@@ -59,7 +61,6 @@ func defaultConfig() Config {
Transport: "streamable_http",
},
Auth: AuthConfig{
Mode: "api_keys",
HeaderName: "x-brain-key",
QueryParam: "key",
},

View File

@@ -18,6 +18,18 @@ func TestResolvePathPrecedence(t *testing.T) {
}
}
func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) {
t.Setenv("AMCS_CONFIG", "/tmp/from-env.yaml")
if got := ResolvePath(".yaml"); got != "/tmp/from-env.yaml" {
t.Fatalf("ResolvePath(.yaml) = %q, want %q", got, "/tmp/from-env.yaml")
}
if got := ResolvePath(".yml"); got != "/tmp/from-env.yaml" {
t.Fatalf("ResolvePath(.yml) = %q, want %q", got, "/tmp/from-env.yaml")
}
}
func TestLoadAppliesEnvOverrides(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "test.yaml")
if err := os.WriteFile(configPath, []byte(`

View File

@@ -10,10 +10,9 @@ func (c Config) Validate() error {
return fmt.Errorf("invalid config: database.url is required")
}
if len(c.Auth.Keys) == 0 {
return fmt.Errorf("invalid config: auth.keys must not be empty")
if len(c.Auth.Keys) == 0 && len(c.Auth.OAuth.Clients) == 0 {
return fmt.Errorf("invalid config: at least one of auth.keys or auth.oauth.clients must be configured")
}
for i, key := range c.Auth.Keys {
if strings.TrimSpace(key.ID) == "" {
return fmt.Errorf("invalid config: auth.keys[%d].id is required", i)
@@ -22,6 +21,14 @@ func (c Config) Validate() error {
return fmt.Errorf("invalid config: auth.keys[%d].value is required", i)
}
}
for i, client := range c.Auth.OAuth.Clients {
if strings.TrimSpace(client.ClientID) == "" {
return fmt.Errorf("invalid config: auth.oauth.clients[%d].client_id is required", i)
}
if strings.TrimSpace(client.ClientSecret) == "" {
return fmt.Errorf("invalid config: auth.oauth.clients[%d].client_secret is required", i)
}
}
if strings.TrimSpace(c.MCP.Path) == "" {
return fmt.Errorf("invalid config: mcp.path is required")

View File

@@ -67,3 +67,47 @@ func TestValidateRejectsEmptyAuthKeyValue(t *testing.T) {
t.Fatal("Validate() error = nil, want error for empty auth key value")
}
}
func TestValidateAcceptsOAuthClients(t *testing.T) {
cfg := validConfig()
cfg.Auth = AuthConfig{
OAuth: OAuthConfig{
Clients: []OAuthClient{{
ID: "oauth-client",
ClientID: "client-id",
ClientSecret: "client-secret",
}},
},
}
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}
func TestValidateAcceptsBothAuthMethods(t *testing.T) {
cfg := validConfig()
cfg.Auth = AuthConfig{
Keys: []APIKey{{ID: "key1", Value: "secret"}},
OAuth: OAuthConfig{
Clients: []OAuthClient{{
ID: "oauth-client",
ClientID: "client-id",
ClientSecret: "client-secret",
}},
},
}
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}
func TestValidateRejectsEmptyAuth(t *testing.T) {
cfg := validConfig()
cfg.Auth = AuthConfig{}
if err := cfg.Validate(); err == nil {
t.Fatal("Validate() error = nil, want error when neither auth.keys nor auth.oauth.clients is configured")
}
}