diff --git a/README.md b/README.md index 5298361..6fb31a5 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,28 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte Config is YAML-driven. Copy `configs/config.example.yaml` and set: - `database.url` — Postgres connection string -- `auth.keys` — API keys for MCP endpoint access via `x-brain-key` or `Authorization: Bearer ` +- `auth.mode` — `api_keys` or `oauth_client_credentials` +- `auth.keys` — API keys for MCP access via `x-brain-key` or `Authorization: Bearer ` when `auth.mode=api_keys` +- `auth.oauth.clients` — client registry when `auth.mode=oauth_client_credentials` + +**OAuth Client Credentials flow** (`auth.mode=oauth_client_credentials`): + +1. Obtain a token — `POST /oauth/token` (public, no auth required): + ``` + POST /oauth/token + Content-Type: application/x-www-form-urlencoded + Authorization: Basic base64(client_id:client_secret) + + grant_type=client_credentials + ``` + Returns: `{"access_token": "...", "token_type": "bearer", "expires_in": 3600}` + +2. Use the token on the MCP endpoint: + ``` + Authorization: Bearer + ``` + +Alternatively, pass `client_id` and `client_secret` as body parameters instead of `Authorization: Basic`. Direct `Authorization: Basic` credential validation on the MCP endpoint is also supported as a fallback (no token required). - `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy - `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server diff --git a/configs/config.example.yaml b/configs/config.example.yaml index 2a763d9..78f3879 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -14,7 +14,6 @@ mcp: transport: "streamable_http" auth: - mode: "api_keys" header_name: "x-brain-key" query_param: "key" allow_query_param: false @@ -22,6 +21,12 @@ auth: - id: "local-client" value: "replace-me" description: "main local client key" + oauth: + clients: + - id: "oauth-client" + client_id: "" + client_secret: "" + description: "used when auth.mode=oauth_client_credentials" database: url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable" diff --git a/configs/dev.yaml b/configs/dev.yaml index 2a763d9..78f3879 100644 --- a/configs/dev.yaml +++ b/configs/dev.yaml @@ -14,7 +14,6 @@ mcp: transport: "streamable_http" auth: - mode: "api_keys" header_name: "x-brain-key" query_param: "key" allow_query_param: false @@ -22,6 +21,12 @@ auth: - id: "local-client" value: "replace-me" description: "main local client key" + oauth: + clients: + - id: "oauth-client" + client_id: "" + client_secret: "" + description: "used when auth.mode=oauth_client_credentials" database: url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable" diff --git a/configs/docker.yaml b/configs/docker.yaml index f7a734f..990c6aa 100644 --- a/configs/docker.yaml +++ b/configs/docker.yaml @@ -14,7 +14,6 @@ mcp: transport: "streamable_http" auth: - mode: "api_keys" header_name: "x-brain-key" query_param: "key" allow_query_param: false @@ -22,6 +21,12 @@ auth: - id: "local-client" value: "replace-me" description: "main local client key" + oauth: + clients: + - id: "oauth-client" + client_id: "" + client_secret: "" + description: "used when auth.mode=oauth_client_credentials" database: url: "postgres://postgres:postgres@db:5432/amcs?sslmode=disable" diff --git a/internal/app/app.go b/internal/app/app.go index bc4a641..59e6bd7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -50,10 +50,24 @@ func Run(ctx context.Context, configPath string) error { return err } - keyring, err := auth.NewKeyring(cfg.Auth.Keys) - if err != nil { - return err + var keyring *auth.Keyring + var oauthRegistry *auth.OAuthRegistry + var tokenStore *auth.TokenStore + if len(cfg.Auth.Keys) > 0 { + keyring, err = auth.NewKeyring(cfg.Auth.Keys) + if err != nil { + return err + } } + if len(cfg.Auth.OAuth.Clients) > 0 { + oauthRegistry, err = auth.NewOAuthRegistry(cfg.Auth.OAuth.Clients) + if err != nil { + return err + } + tokenStore = auth.NewTokenStore(0) + } + authCodes := auth.NewAuthCodeStore() + dynClients := auth.NewDynamicClientStore() activeProjects := session.NewActiveProjects() logger.Info("database connection verified", @@ -62,7 +76,7 @@ func Run(ctx context.Context, configPath string) error { server := &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), - Handler: routes(logger, cfg, db, provider, keyring, activeProjects), + Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects), ReadTimeout: cfg.Server.ReadTimeout, WriteTimeout: cfg.Server.WriteTimeout, IdleTimeout: cfg.Server.IdleTimeout, @@ -92,7 +106,7 @@ func Run(ctx context.Context, configPath string) error { } } -func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, activeProjects *session.ActiveProjects) http.Handler { +func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) http.Handler { mux := http.NewServeMux() toolSet := mcpserver.ToolSet{ @@ -112,7 +126,13 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P } mcpHandler := mcpserver.New(cfg.MCP, toolSet) - mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, logger)(mcpHandler)) + mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, logger)(mcpHandler)) + if oauthRegistry != nil && tokenStore != nil { + mux.HandleFunc("/.well-known/oauth-authorization-server", oauthMetadataHandler()) + mux.HandleFunc("/oauth/register", oauthRegisterHandler(dynClients, logger)) + mux.HandleFunc("/oauth/authorize", oauthAuthorizeHandler(dynClients, authCodes, logger)) + mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger)) + } mux.HandleFunc("/favicon.ico", serveFavicon) mux.HandleFunc("/llm", serveLLMInstructions) diff --git a/internal/app/oauth.go b/internal/app/oauth.go new file mode 100644 index 0000000..878fe75 --- /dev/null +++ b/internal/app/oauth.go @@ -0,0 +1,390 @@ +package app + +import ( + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "git.warky.dev/wdevs/amcs/internal/auth" +) + +// --- JSON types --- + +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +type tokenErrorResponse struct { + Error string `json:"error"` +} + +type oauthServerMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` +} + +type registerRequest struct { + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` +} + +type registerResponse struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` +} + +// --- Handlers --- + +// oauthMetadataHandler serves GET /.well-known/oauth-authorization-server +// per RFC 8414 for OAuth 2.0 server metadata discovery. +func oauthMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + base := serverBaseURL(r) + meta := oauthServerMetadata{ + Issuer: base, + AuthorizationEndpoint: base + "/oauth/authorize", + TokenEndpoint: base + "/oauth/token", + RegistrationEndpoint: base + "/oauth/register", + ScopesSupported: []string{"mcp"}, + ResponseTypesSupported: []string{"code"}, + GrantTypesSupported: []string{"authorization_code", "client_credentials"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "none"}, + CodeChallengeMethodsSupported: []string{"S256"}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(meta) + } +} + +// oauthRegisterHandler serves POST /oauth/register per RFC 7591 +// (OAuth 2.0 Dynamic Client Registration). +func oauthRegisterHandler(dynClients *auth.DynamicClientStore, log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req registerRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + if len(req.RedirectURIs) == 0 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_client_metadata","error_description":"redirect_uris is required"}`)) + return + } + + client, err := dynClients.Register(req.ClientName, req.RedirectURIs) + if err != nil { + log.Error("oauth register: failed", slog.String("error", err.Error())) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + log.Info("oauth register: new client", + slog.String("client_id", client.ClientID), + slog.String("client_name", client.ClientName), + ) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(registerResponse{ + ClientID: client.ClientID, + ClientName: client.ClientName, + RedirectURIs: client.RedirectURIs, + GrantTypes: []string{"authorization_code"}, + ResponseTypes: []string{"code"}, + TokenEndpointAuthMethod: "none", + }) + } +} + +// oauthAuthorizeHandler serves GET and POST /oauth/authorize. +// GET shows an approval page; POST processes the user's approve/deny action. +func oauthAuthorizeHandler(dynClients *auth.DynamicClientStore, authCodes *auth.AuthCodeStore, log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + handleAuthorizeGET(w, r, dynClients) + case http.MethodPost: + handleAuthorizePOST(w, r, dynClients, authCodes, log) + default: + w.Header().Set("Allow", "GET, POST") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + } +} + +func handleAuthorizeGET(w http.ResponseWriter, r *http.Request, dynClients *auth.DynamicClientStore) { + q := r.URL.Query() + clientID := q.Get("client_id") + redirectURI := q.Get("redirect_uri") + responseType := q.Get("response_type") + state := q.Get("state") + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + scope := q.Get("scope") + + // Validate client and redirect_uri before any redirect (prevents open redirect). + client, ok := dynClients.Lookup(clientID) + if !ok { + http.Error(w, "unknown client_id", http.StatusBadRequest) + return + } + if !client.HasRedirectURI(redirectURI) { + http.Error(w, "redirect_uri not registered for this client", http.StatusBadRequest) + return + } + + // Errors from here can safely redirect back to the client. + if responseType != "code" { + oauthRedirectError(w, r, redirectURI, "unsupported_response_type", state) + return + } + if codeChallenge == "" || codeChallengeMethod != "S256" { + oauthRedirectError(w, r, redirectURI, "invalid_request", state) + return + } + + serveAuthorizePage(w, client.ClientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope) +} + +func handleAuthorizePOST(w http.ResponseWriter, r *http.Request, dynClients *auth.DynamicClientStore, authCodes *auth.AuthCodeStore, log *slog.Logger) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + clientID := r.FormValue("client_id") + redirectURI := r.FormValue("redirect_uri") + state := r.FormValue("state") + codeChallenge := r.FormValue("code_challenge") + codeChallengeMethod := r.FormValue("code_challenge_method") + scope := r.FormValue("scope") + action := r.FormValue("action") + + client, ok := dynClients.Lookup(clientID) + if !ok { + http.Error(w, "unknown client_id", http.StatusBadRequest) + return + } + if !client.HasRedirectURI(redirectURI) { + http.Error(w, "redirect_uri not registered for this client", http.StatusBadRequest) + return + } + + if action == "deny" { + oauthRedirectError(w, r, redirectURI, "access_denied", state) + return + } + + code, err := authCodes.Issue(auth.AuthCode{ + ClientID: clientID, + RedirectURI: redirectURI, + Scope: scope, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + KeyID: clientID, + }) + if err != nil { + log.Error("oauth authorize: failed to issue code", slog.String("error", err.Error())) + oauthRedirectError(w, r, redirectURI, "server_error", state) + return + } + + target := redirectURI + "?code=" + url.QueryEscape(code) + if state != "" { + target += "&state=" + url.QueryEscape(state) + } + http.Redirect(w, r, target, http.StatusFound) +} + +// oauthTokenHandler serves POST /oauth/token. +// Supports grant_type=client_credentials and grant_type=authorization_code. +func oauthTokenHandler(oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + writeTokenError(w, "invalid_request", http.StatusBadRequest) + return + } + + switch r.FormValue("grant_type") { + case "client_credentials": + handleClientCredentials(w, r, oauthRegistry, tokenStore, log) + case "authorization_code": + handleAuthorizationCode(w, r, authCodes, tokenStore, log) + default: + writeTokenError(w, "unsupported_grant_type", http.StatusBadRequest) + } + } +} + +func handleClientCredentials(w http.ResponseWriter, r *http.Request, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, log *slog.Logger) { + clientID, clientSecret := extractOAuthBasicOrBody(r) + if clientID == "" || clientSecret == "" { + w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`) + writeTokenError(w, "invalid_client", http.StatusUnauthorized) + return + } + keyID, ok := oauthRegistry.Lookup(clientID, clientSecret) + if !ok { + log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", r.RemoteAddr)) + w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`) + writeTokenError(w, "invalid_client", http.StatusUnauthorized) + return + } + issueToken(w, keyID, tokenStore, log) +} + +func handleAuthorizationCode(w http.ResponseWriter, r *http.Request, authCodes *auth.AuthCodeStore, tokenStore *auth.TokenStore, log *slog.Logger) { + code := r.FormValue("code") + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + codeVerifier := r.FormValue("code_verifier") + + if code == "" || redirectURI == "" || codeVerifier == "" { + writeTokenError(w, "invalid_request", http.StatusBadRequest) + return + } + + entry, ok := authCodes.Consume(code) + if !ok { + writeTokenError(w, "invalid_grant", http.StatusBadRequest) + return + } + if entry.ClientID != clientID || entry.RedirectURI != redirectURI { + writeTokenError(w, "invalid_grant", http.StatusBadRequest) + return + } + if !verifyPKCE(codeVerifier, entry.CodeChallenge, entry.CodeChallengeMethod) { + log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", r.RemoteAddr)) + writeTokenError(w, "invalid_grant", http.StatusBadRequest) + return + } + + issueToken(w, entry.KeyID, tokenStore, log) +} + +func issueToken(w http.ResponseWriter, keyID string, tokenStore *auth.TokenStore, log *slog.Logger) { + token, ttl, err := tokenStore.Issue(keyID) + if err != nil { + log.Error("oauth token: failed to issue token", slog.String("error", err.Error())) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + _ = json.NewEncoder(w).Encode(tokenResponse{ + AccessToken: token, + TokenType: "bearer", + ExpiresIn: int(ttl / time.Second), + }) +} + +// --- Helpers --- + +func serveAuthorizePage(w http.ResponseWriter, clientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope string) { + e := html.EscapeString + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprintf(w, ` + + + +Authorize — AMCS + + + +

Authorize Access

+

%s is requesting access to this AMCS server.

+
+ + + + + + + + +
+ +`, + e(clientName), e(clientID), e(redirectURI), e(state), + e(codeChallenge), e(codeChallengeMethod), e(scope)) +} + +func oauthRedirectError(w http.ResponseWriter, r *http.Request, redirectURI, errCode, state string) { + target := redirectURI + "?error=" + url.QueryEscape(errCode) + if state != "" { + target += "&state=" + url.QueryEscape(state) + } + http.Redirect(w, r, target, http.StatusFound) +} + +func verifyPKCE(verifier, challenge, method string) bool { + if method != "S256" { + return false + } + h := sha256.Sum256([]byte(verifier)) + got := base64.RawURLEncoding.EncodeToString(h[:]) + return subtle.ConstantTimeCompare([]byte(got), []byte(challenge)) == 1 +} + +func serverBaseURL(r *http.Request) string { + scheme := "https" + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + scheme = strings.ToLower(proto) + } else if r.TLS == nil { + scheme = "http" + } + return scheme + "://" + r.Host +} + +func extractOAuthBasicOrBody(r *http.Request) (string, string) { + if id, secret, ok := r.BasicAuth(); ok { + return id, secret + } + return strings.TrimSpace(r.FormValue("client_id")), strings.TrimSpace(r.FormValue("client_secret")) +} + +func writeTokenError(w http.ResponseWriter, errCode string, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(tokenErrorResponse{Error: errCode}) +} diff --git a/internal/auth/auth_code_store.go b/internal/auth/auth_code_store.go new file mode 100644 index 0000000..e44a023 --- /dev/null +++ b/internal/auth/auth_code_store.go @@ -0,0 +1,76 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "sync" + "time" +) + +const authCodeTTL = 10 * time.Minute + +// AuthCode holds a pending authorization code and its associated PKCE data. +type AuthCode struct { + ClientID string + RedirectURI string + Scope string + CodeChallenge string + CodeChallengeMethod string + KeyID string + ExpiresAt time.Time +} + +// AuthCodeStore issues single-use authorization codes for the OAuth 2.0 +// Authorization Code flow. +type AuthCodeStore struct { + mu sync.Mutex + codes map[string]AuthCode +} + +func NewAuthCodeStore() *AuthCodeStore { + s := &AuthCodeStore{codes: make(map[string]AuthCode)} + go s.sweepLoop() + return s +} + +// Issue stores the entry and returns the raw authorization code. +func (s *AuthCodeStore) Issue(entry AuthCode) (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + raw := hex.EncodeToString(b) + entry.ExpiresAt = time.Now().Add(authCodeTTL) + s.mu.Lock() + s.codes[raw] = entry + s.mu.Unlock() + return raw, nil +} + +// Consume validates and removes the code, returning the associated entry. +func (s *AuthCodeStore) Consume(code string) (AuthCode, bool) { + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.codes[code] + if !ok || time.Now().After(entry.ExpiresAt) { + delete(s.codes, code) + return AuthCode{}, false + } + delete(s.codes, code) + return entry, true +} + +func (s *AuthCodeStore) sweepLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + now := time.Now() + s.mu.Lock() + for code, entry := range s.codes { + if now.After(entry.ExpiresAt) { + delete(s.codes, code) + } + } + s.mu.Unlock() + } +} diff --git a/internal/auth/dynamic_client_store.go b/internal/auth/dynamic_client_store.go new file mode 100644 index 0000000..c137915 --- /dev/null +++ b/internal/auth/dynamic_client_store.go @@ -0,0 +1,62 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "sync" + "time" +) + +// DynamicClient holds a dynamically registered OAuth client (RFC 7591). +type DynamicClient struct { + ClientID string + ClientName string + RedirectURIs []string + CreatedAt time.Time +} + +// HasRedirectURI reports whether uri is registered for this client. +func (c *DynamicClient) HasRedirectURI(uri string) bool { + for _, u := range c.RedirectURIs { + if u == uri { + return true + } + } + return false +} + +// DynamicClientStore holds dynamically registered OAuth clients in memory. +type DynamicClientStore struct { + mu sync.RWMutex + clients map[string]DynamicClient +} + +func NewDynamicClientStore() *DynamicClientStore { + return &DynamicClientStore{clients: make(map[string]DynamicClient)} +} + +// Register creates a new client and returns it. +func (s *DynamicClientStore) Register(name string, redirectURIs []string) (DynamicClient, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return DynamicClient{}, err + } + client := DynamicClient{ + ClientID: hex.EncodeToString(b), + ClientName: name, + RedirectURIs: append([]string(nil), redirectURIs...), + CreatedAt: time.Now(), + } + s.mu.Lock() + s.clients[client.ClientID] = client + s.mu.Unlock() + return client, nil +} + +// Lookup returns the client for the given client_id. +func (s *DynamicClientStore) Lookup(clientID string) (DynamicClient, bool) { + s.mu.RLock() + client, ok := s.clients[clientID] + s.mu.RUnlock() + return client, ok +} diff --git a/internal/auth/keyring_test.go b/internal/auth/keyring_test.go index 6fda6b6..708623e 100644 --- a/internal/auth/keyring_test.go +++ b/internal/auth/keyring_test.go @@ -39,7 +39,7 @@ func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) { t.Fatalf("NewKeyring() error = %v", err) } - handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "client-a" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) @@ -63,7 +63,7 @@ func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) { t.Fatalf("NewKeyring() error = %v", err) } - handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "client-a" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) @@ -90,7 +90,7 @@ func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) { t.Fatalf("NewKeyring() error = %v", err) } - handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "client-a" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) @@ -119,7 +119,7 @@ func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) { HeaderName: "x-brain-key", QueryParam: "key", AllowQueryParam: true, - }, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + }, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })) @@ -138,7 +138,7 @@ func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) { t.Fatalf("NewKeyring() error = %v", err) } - handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next handler should not be called") })) diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index d12abed..4395e00 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -2,6 +2,7 @@ package auth import ( "context" + "encoding/base64" "log/slog" "net/http" "strings" @@ -13,32 +14,77 @@ type contextKey string const keyIDContextKey contextKey = "auth.key_id" -func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler { +func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, log *slog.Logger) func(http.Handler) http.Handler { headerName := cfg.HeaderName if headerName == "" { headerName = "x-brain-key" } - return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := extractToken(r, headerName) - if token == "" && cfg.AllowQueryParam { - token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)) + // 1. Custom header → keyring only. + if keyring != nil { + if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" { + keyID, ok := keyring.Lookup(token) + if !ok { + log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr)) + http.Error(w, "invalid API key", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) + return + } } - if token == "" { - http.Error(w, "missing API key", http.StatusUnauthorized) + // 2. Bearer token → tokenStore (OAuth), then keyring (API key). + if bearer := extractBearer(r); bearer != "" { + if tokenStore != nil { + if keyID, ok := tokenStore.Lookup(bearer); ok { + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) + return + } + } + if keyring != nil { + if keyID, ok := keyring.Lookup(bearer); ok { + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) + return + } + } + log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr)) + http.Error(w, "invalid token or API key", http.StatusUnauthorized) return } - keyID, ok := keyring.Lookup(token) - if !ok { - log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr)) - http.Error(w, "invalid API key", http.StatusUnauthorized) + // 3. HTTP Basic → oauthRegistry (direct client credentials). + if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" { + if oauthRegistry == nil { + http.Error(w, "authentication is not configured", http.StatusUnauthorized) + return + } + keyID, ok := oauthRegistry.Lookup(clientID, clientSecret) + if !ok { + log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr)) + http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) return } - next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) + // 4. Query param → keyring. + if keyring != nil && cfg.AllowQueryParam { + if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" { + keyID, ok := keyring.Lookup(token) + if !ok { + log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr)) + http.Error(w, "invalid API key", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) + return + } + } + + http.Error(w, "authentication required", http.StatusUnauthorized) }) } } @@ -58,6 +104,30 @@ func extractToken(r *http.Request, headerName string) string { return strings.TrimSpace(credentials) } +func extractBearer(r *http.Request) string { + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + scheme, credentials, ok := strings.Cut(authHeader, " ") + if !ok || !strings.EqualFold(scheme, "Bearer") { + return "" + } + return strings.TrimSpace(credentials) +} + +func extractOAuthClientCredentials(r *http.Request) (string, string) { + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + scheme, credentials, ok := strings.Cut(authHeader, " ") + if ok && strings.EqualFold(scheme, "Basic") { + decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(credentials)) + if err == nil { + clientID, clientSecret, found := strings.Cut(string(decoded), ":") + if found { + return strings.TrimSpace(clientID), strings.TrimSpace(clientSecret) + } + } + } + return "", "" +} + func KeyIDFromContext(ctx context.Context) (string, bool) { value, ok := ctx.Value(keyIDContextKey).(string) return value, ok diff --git a/internal/auth/oauth_registry.go b/internal/auth/oauth_registry.go new file mode 100644 index 0000000..e512bd4 --- /dev/null +++ b/internal/auth/oauth_registry.go @@ -0,0 +1,33 @@ +package auth + +import ( + "crypto/subtle" + "fmt" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +type OAuthRegistry struct { + clients []config.OAuthClient +} + +func NewOAuthRegistry(clients []config.OAuthClient) (*OAuthRegistry, error) { + if len(clients) == 0 { + return nil, fmt.Errorf("oauth registry requires at least one client") + } + + return &OAuthRegistry{clients: append([]config.OAuthClient(nil), clients...)}, nil +} + +func (o *OAuthRegistry) Lookup(clientID string, clientSecret string) (string, bool) { + for _, client := range o.clients { + if subtle.ConstantTimeCompare([]byte(client.ClientID), []byte(clientID)) == 1 && + subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 1 { + if client.ID != "" { + return client.ID, true + } + return client.ClientID, true + } + } + return "", false +} diff --git a/internal/auth/oauth_registry_test.go b/internal/auth/oauth_registry_test.go new file mode 100644 index 0000000..752be3f --- /dev/null +++ b/internal/auth/oauth_registry_test.go @@ -0,0 +1,92 @@ +package auth + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func TestNewOAuthRegistryAndLookup(t *testing.T) { + _, err := NewOAuthRegistry(nil) + if err == nil { + t.Fatal("NewOAuthRegistry(nil) error = nil, want error") + } + + registry, err := NewOAuthRegistry([]config.OAuthClient{{ + ID: "oauth-client", + ClientID: "client-id", + ClientSecret: "client-secret", + }}) + if err != nil { + t.Fatalf("NewOAuthRegistry() error = %v", err) + } + + if got, ok := registry.Lookup("client-id", "client-secret"); !ok || got != "oauth-client" { + t.Fatalf("Lookup(client-id, client-secret) = (%q, %v), want (oauth-client, true)", got, ok) + } + if _, ok := registry.Lookup("client-id", "wrong"); ok { + t.Fatal("Lookup(client-id, wrong) = true, want false") + } +} + +func TestMiddlewareAllowsOAuthBasicAuthAndSetsContext(t *testing.T) { + oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{ + ID: "oauth-client", + ClientID: "client-id", + ClientSecret: "client-secret", + }}) + if err != nil { + t.Fatalf("NewOAuthRegistry() error = %v", err) + } + + handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + keyID, ok := KeyIDFromContext(r.Context()) + if !ok || keyID != "oauth-client" { + t.Fatalf("KeyIDFromContext() = (%q, %v), want (oauth-client, true)", keyID, ok) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:client-secret"))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + + +func TestMiddlewareRejectsOAuthMissingOrInvalidCredentials(t *testing.T) { + oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{ + ID: "oauth-client", + ClientID: "client-id", + ClientSecret: "client-secret", + }}) + if err != nil { + t.Fatalf("NewOAuthRegistry() error = %v", err) + } + + handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next handler should not be called") + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("missing credentials status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + + req = httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:wrong"))) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("invalid credentials status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} diff --git a/internal/auth/token_store.go b/internal/auth/token_store.go new file mode 100644 index 0000000..f404973 --- /dev/null +++ b/internal/auth/token_store.go @@ -0,0 +1,74 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "sync" + "time" +) + +const defaultTokenTTL = time.Hour + +type tokenEntry struct { + keyID string + expiresAt time.Time +} + +// TokenStore issues and validates short-lived opaque access tokens for OAuth +// client credentials flow. +type TokenStore struct { + mu sync.RWMutex + tokens map[string]tokenEntry + ttl time.Duration +} + +func NewTokenStore(ttl time.Duration) *TokenStore { + if ttl <= 0 { + ttl = defaultTokenTTL + } + s := &TokenStore{ + tokens: make(map[string]tokenEntry), + ttl: ttl, + } + go s.sweepLoop() + return s +} + +// Issue generates a new token for the given keyID and returns the token and its TTL. +func (s *TokenStore) Issue(keyID string) (string, time.Duration, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", 0, err + } + token := hex.EncodeToString(b) + s.mu.Lock() + s.tokens[token] = tokenEntry{keyID: keyID, expiresAt: time.Now().Add(s.ttl)} + s.mu.Unlock() + return token, s.ttl, nil +} + +// Lookup validates a token and returns the associated keyID. +func (s *TokenStore) Lookup(token string) (string, bool) { + s.mu.RLock() + entry, ok := s.tokens[token] + s.mu.RUnlock() + if !ok || time.Now().After(entry.expiresAt) { + return "", false + } + return entry.keyID, true +} + +func (s *TokenStore) sweepLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + now := time.Now() + s.mu.Lock() + for token, entry := range s.tokens { + if now.After(entry.expiresAt) { + delete(s.tokens, token) + } + } + s.mu.Unlock() + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 3d8bd7a..62a1796 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,11 +36,11 @@ type MCPConfig struct { } type AuthConfig struct { - Mode string `yaml:"mode"` - HeaderName string `yaml:"header_name"` - QueryParam string `yaml:"query_param"` - AllowQueryParam bool `yaml:"allow_query_param"` - Keys []APIKey `yaml:"keys"` + HeaderName string `yaml:"header_name"` + QueryParam string `yaml:"query_param"` + AllowQueryParam bool `yaml:"allow_query_param"` + Keys []APIKey `yaml:"keys"` + OAuth OAuthConfig `yaml:"oauth"` } type APIKey struct { @@ -49,6 +49,17 @@ type APIKey struct { Description string `yaml:"description"` } +type OAuthConfig struct { + Clients []OAuthClient `yaml:"clients"` +} + +type OAuthClient struct { + ID string `yaml:"id"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + Description string `yaml:"description"` +} + type DatabaseConfig struct { URL string `yaml:"url"` MaxConns int32 `yaml:"max_conns"` diff --git a/internal/config/loader.go b/internal/config/loader.go index 000e23d..47d1589 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -32,8 +32,10 @@ func Load(explicitPath string) (*Config, string, error) { } func ResolvePath(explicitPath string) string { - if strings.TrimSpace(explicitPath) != "" { - return explicitPath + if path := strings.TrimSpace(explicitPath); path != "" { + if path != ".yaml" && path != ".yml" { + return path + } } if envPath := strings.TrimSpace(os.Getenv("AMCS_CONFIG")); envPath != "" { @@ -59,7 +61,6 @@ func defaultConfig() Config { Transport: "streamable_http", }, Auth: AuthConfig{ - Mode: "api_keys", HeaderName: "x-brain-key", QueryParam: "key", }, diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index e10188e..f288801 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -18,6 +18,18 @@ func TestResolvePathPrecedence(t *testing.T) { } } +func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) { + t.Setenv("AMCS_CONFIG", "/tmp/from-env.yaml") + + if got := ResolvePath(".yaml"); got != "/tmp/from-env.yaml" { + t.Fatalf("ResolvePath(.yaml) = %q, want %q", got, "/tmp/from-env.yaml") + } + + if got := ResolvePath(".yml"); got != "/tmp/from-env.yaml" { + t.Fatalf("ResolvePath(.yml) = %q, want %q", got, "/tmp/from-env.yaml") + } +} + func TestLoadAppliesEnvOverrides(t *testing.T) { configPath := filepath.Join(t.TempDir(), "test.yaml") if err := os.WriteFile(configPath, []byte(` diff --git a/internal/config/validate.go b/internal/config/validate.go index a9559bf..d0a7996 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -10,10 +10,9 @@ func (c Config) Validate() error { return fmt.Errorf("invalid config: database.url is required") } - if len(c.Auth.Keys) == 0 { - return fmt.Errorf("invalid config: auth.keys must not be empty") + if len(c.Auth.Keys) == 0 && len(c.Auth.OAuth.Clients) == 0 { + return fmt.Errorf("invalid config: at least one of auth.keys or auth.oauth.clients must be configured") } - for i, key := range c.Auth.Keys { if strings.TrimSpace(key.ID) == "" { return fmt.Errorf("invalid config: auth.keys[%d].id is required", i) @@ -22,6 +21,14 @@ func (c Config) Validate() error { return fmt.Errorf("invalid config: auth.keys[%d].value is required", i) } } + for i, client := range c.Auth.OAuth.Clients { + if strings.TrimSpace(client.ClientID) == "" { + return fmt.Errorf("invalid config: auth.oauth.clients[%d].client_id is required", i) + } + if strings.TrimSpace(client.ClientSecret) == "" { + return fmt.Errorf("invalid config: auth.oauth.clients[%d].client_secret is required", i) + } + } if strings.TrimSpace(c.MCP.Path) == "" { return fmt.Errorf("invalid config: mcp.path is required") diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 735bf86..a701c0d 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -67,3 +67,47 @@ func TestValidateRejectsEmptyAuthKeyValue(t *testing.T) { t.Fatal("Validate() error = nil, want error for empty auth key value") } } + +func TestValidateAcceptsOAuthClients(t *testing.T) { + cfg := validConfig() + cfg.Auth = AuthConfig{ + OAuth: OAuthConfig{ + Clients: []OAuthClient{{ + ID: "oauth-client", + ClientID: "client-id", + ClientSecret: "client-secret", + }}, + }, + } + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestValidateAcceptsBothAuthMethods(t *testing.T) { + cfg := validConfig() + cfg.Auth = AuthConfig{ + Keys: []APIKey{{ID: "key1", Value: "secret"}}, + OAuth: OAuthConfig{ + Clients: []OAuthClient{{ + ID: "oauth-client", + ClientID: "client-id", + ClientSecret: "client-secret", + }}, + }, + } + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestValidateRejectsEmptyAuth(t *testing.T) { + cfg := validConfig() + cfg.Auth = AuthConfig{} + + if err := cfg.Validate(); err == nil { + t.Fatal("Validate() error = nil, want error when neither auth.keys nor auth.oauth.clients is configured") + } +} diff --git a/llm/memory.md b/llm/memory.md index 9520e68..3bfce45 100644 --- a/llm/memory.md +++ b/llm/memory.md @@ -1,5 +1,7 @@ # AMCS Memory Instructions +AMCS (Avalon Memory Crystal Server) is an MCP server for capturing and retrieving thoughts, memory, and project context. It is backed by Postgres with pgvector for semantic search. + You have access to an MCP memory server named AMCS. Use AMCS as memory with two scopes: @@ -19,8 +21,8 @@ Use AMCS as memory with two scopes: ## Project Memory Rules - Use project memory for code decisions, architecture, TODOs, debugging findings, and context specific to the current repo or workstream. -- Before substantial work, retrieve context with `get_project_context` or `recall_context`. -- Save durable project facts with `capture_thought`. +- Before substantial work, always retrieve context with `get_project_context` or `recall_context` so prior decisions inform your approach. +- Save durable project facts with `capture_thought` after completing meaningful work. - Do not attach memory to the wrong project. ## Global Notebook Rules