feat(auth): implement OAuth 2.0 authorization code flow and dynamic client registration
- Add OAuth 2.0 support with authorization code flow and dynamic client registration. - Introduce new handlers for OAuth metadata, client registration, authorization, and token issuance. - Enhance authentication middleware to support OAuth client credentials. - Create in-memory stores for authorization codes and tokens. - Update configuration to include OAuth client details. - Ensure validation checks for OAuth clients in the configuration.
This commit is contained in:
23
README.md
23
README.md
@@ -47,7 +47,28 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte
|
|||||||
Config is YAML-driven. Copy `configs/config.example.yaml` and set:
|
Config is YAML-driven. Copy `configs/config.example.yaml` and set:
|
||||||
|
|
||||||
- `database.url` — Postgres connection string
|
- `database.url` — Postgres connection string
|
||||||
- `auth.keys` — API keys for MCP endpoint access via `x-brain-key` or `Authorization: Bearer <key>`
|
- `auth.mode` — `api_keys` or `oauth_client_credentials`
|
||||||
|
- `auth.keys` — API keys for MCP access via `x-brain-key` or `Authorization: Bearer <key>` when `auth.mode=api_keys`
|
||||||
|
- `auth.oauth.clients` — client registry when `auth.mode=oauth_client_credentials`
|
||||||
|
|
||||||
|
**OAuth Client Credentials flow** (`auth.mode=oauth_client_credentials`):
|
||||||
|
|
||||||
|
1. Obtain a token — `POST /oauth/token` (public, no auth required):
|
||||||
|
```
|
||||||
|
POST /oauth/token
|
||||||
|
Content-Type: application/x-www-form-urlencoded
|
||||||
|
Authorization: Basic base64(client_id:client_secret)
|
||||||
|
|
||||||
|
grant_type=client_credentials
|
||||||
|
```
|
||||||
|
Returns: `{"access_token": "...", "token_type": "bearer", "expires_in": 3600}`
|
||||||
|
|
||||||
|
2. Use the token on the MCP endpoint:
|
||||||
|
```
|
||||||
|
Authorization: Bearer <access_token>
|
||||||
|
```
|
||||||
|
|
||||||
|
Alternatively, pass `client_id` and `client_secret` as body parameters instead of `Authorization: Basic`. Direct `Authorization: Basic` credential validation on the MCP endpoint is also supported as a fallback (no token required).
|
||||||
- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy
|
- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy
|
||||||
- `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server
|
- `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ mcp:
|
|||||||
transport: "streamable_http"
|
transport: "streamable_http"
|
||||||
|
|
||||||
auth:
|
auth:
|
||||||
mode: "api_keys"
|
|
||||||
header_name: "x-brain-key"
|
header_name: "x-brain-key"
|
||||||
query_param: "key"
|
query_param: "key"
|
||||||
allow_query_param: false
|
allow_query_param: false
|
||||||
@@ -22,6 +21,12 @@ auth:
|
|||||||
- id: "local-client"
|
- id: "local-client"
|
||||||
value: "replace-me"
|
value: "replace-me"
|
||||||
description: "main local client key"
|
description: "main local client key"
|
||||||
|
oauth:
|
||||||
|
clients:
|
||||||
|
- id: "oauth-client"
|
||||||
|
client_id: ""
|
||||||
|
client_secret: ""
|
||||||
|
description: "used when auth.mode=oauth_client_credentials"
|
||||||
|
|
||||||
database:
|
database:
|
||||||
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ mcp:
|
|||||||
transport: "streamable_http"
|
transport: "streamable_http"
|
||||||
|
|
||||||
auth:
|
auth:
|
||||||
mode: "api_keys"
|
|
||||||
header_name: "x-brain-key"
|
header_name: "x-brain-key"
|
||||||
query_param: "key"
|
query_param: "key"
|
||||||
allow_query_param: false
|
allow_query_param: false
|
||||||
@@ -22,6 +21,12 @@ auth:
|
|||||||
- id: "local-client"
|
- id: "local-client"
|
||||||
value: "replace-me"
|
value: "replace-me"
|
||||||
description: "main local client key"
|
description: "main local client key"
|
||||||
|
oauth:
|
||||||
|
clients:
|
||||||
|
- id: "oauth-client"
|
||||||
|
client_id: ""
|
||||||
|
client_secret: ""
|
||||||
|
description: "used when auth.mode=oauth_client_credentials"
|
||||||
|
|
||||||
database:
|
database:
|
||||||
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ mcp:
|
|||||||
transport: "streamable_http"
|
transport: "streamable_http"
|
||||||
|
|
||||||
auth:
|
auth:
|
||||||
mode: "api_keys"
|
|
||||||
header_name: "x-brain-key"
|
header_name: "x-brain-key"
|
||||||
query_param: "key"
|
query_param: "key"
|
||||||
allow_query_param: false
|
allow_query_param: false
|
||||||
@@ -22,6 +21,12 @@ auth:
|
|||||||
- id: "local-client"
|
- id: "local-client"
|
||||||
value: "replace-me"
|
value: "replace-me"
|
||||||
description: "main local client key"
|
description: "main local client key"
|
||||||
|
oauth:
|
||||||
|
clients:
|
||||||
|
- id: "oauth-client"
|
||||||
|
client_id: ""
|
||||||
|
client_secret: ""
|
||||||
|
description: "used when auth.mode=oauth_client_credentials"
|
||||||
|
|
||||||
database:
|
database:
|
||||||
url: "postgres://postgres:postgres@db:5432/amcs?sslmode=disable"
|
url: "postgres://postgres:postgres@db:5432/amcs?sslmode=disable"
|
||||||
|
|||||||
@@ -50,10 +50,24 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyring, err := auth.NewKeyring(cfg.Auth.Keys)
|
var keyring *auth.Keyring
|
||||||
if err != nil {
|
var oauthRegistry *auth.OAuthRegistry
|
||||||
return err
|
var tokenStore *auth.TokenStore
|
||||||
|
if len(cfg.Auth.Keys) > 0 {
|
||||||
|
keyring, err = auth.NewKeyring(cfg.Auth.Keys)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if len(cfg.Auth.OAuth.Clients) > 0 {
|
||||||
|
oauthRegistry, err = auth.NewOAuthRegistry(cfg.Auth.OAuth.Clients)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tokenStore = auth.NewTokenStore(0)
|
||||||
|
}
|
||||||
|
authCodes := auth.NewAuthCodeStore()
|
||||||
|
dynClients := auth.NewDynamicClientStore()
|
||||||
activeProjects := session.NewActiveProjects()
|
activeProjects := session.NewActiveProjects()
|
||||||
|
|
||||||
logger.Info("database connection verified",
|
logger.Info("database connection verified",
|
||||||
@@ -62,7 +76,7 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||||
Handler: routes(logger, cfg, db, provider, keyring, activeProjects),
|
Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects),
|
||||||
ReadTimeout: cfg.Server.ReadTimeout,
|
ReadTimeout: cfg.Server.ReadTimeout,
|
||||||
WriteTimeout: cfg.Server.WriteTimeout,
|
WriteTimeout: cfg.Server.WriteTimeout,
|
||||||
IdleTimeout: cfg.Server.IdleTimeout,
|
IdleTimeout: cfg.Server.IdleTimeout,
|
||||||
@@ -92,7 +106,7 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, activeProjects *session.ActiveProjects) http.Handler {
|
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
toolSet := mcpserver.ToolSet{
|
toolSet := mcpserver.ToolSet{
|
||||||
@@ -112,7 +126,13 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
|||||||
}
|
}
|
||||||
|
|
||||||
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
|
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
|
||||||
mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, logger)(mcpHandler))
|
mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, logger)(mcpHandler))
|
||||||
|
if oauthRegistry != nil && tokenStore != nil {
|
||||||
|
mux.HandleFunc("/.well-known/oauth-authorization-server", oauthMetadataHandler())
|
||||||
|
mux.HandleFunc("/oauth/register", oauthRegisterHandler(dynClients, logger))
|
||||||
|
mux.HandleFunc("/oauth/authorize", oauthAuthorizeHandler(dynClients, authCodes, logger))
|
||||||
|
mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger))
|
||||||
|
}
|
||||||
mux.HandleFunc("/favicon.ico", serveFavicon)
|
mux.HandleFunc("/favicon.ico", serveFavicon)
|
||||||
mux.HandleFunc("/llm", serveLLMInstructions)
|
mux.HandleFunc("/llm", serveLLMInstructions)
|
||||||
|
|
||||||
|
|||||||
390
internal/app/oauth.go
Normal file
390
internal/app/oauth.go
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- JSON types ---
|
||||||
|
|
||||||
|
type tokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenErrorResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type oauthServerMetadata struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint"`
|
||||||
|
RegistrationEndpoint string `json:"registration_endpoint"`
|
||||||
|
ScopesSupported []string `json:"scopes_supported"`
|
||||||
|
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||||
|
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||||
|
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
|
||||||
|
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type registerRequest struct {
|
||||||
|
ClientName string `json:"client_name"`
|
||||||
|
RedirectURIs []string `json:"redirect_uris"`
|
||||||
|
GrantTypes []string `json:"grant_types"`
|
||||||
|
ResponseTypes []string `json:"response_types"`
|
||||||
|
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type registerResponse struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientName string `json:"client_name"`
|
||||||
|
RedirectURIs []string `json:"redirect_uris"`
|
||||||
|
GrantTypes []string `json:"grant_types"`
|
||||||
|
ResponseTypes []string `json:"response_types"`
|
||||||
|
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Handlers ---
|
||||||
|
|
||||||
|
// oauthMetadataHandler serves GET /.well-known/oauth-authorization-server
|
||||||
|
// per RFC 8414 for OAuth 2.0 server metadata discovery.
|
||||||
|
func oauthMetadataHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
base := serverBaseURL(r)
|
||||||
|
meta := oauthServerMetadata{
|
||||||
|
Issuer: base,
|
||||||
|
AuthorizationEndpoint: base + "/oauth/authorize",
|
||||||
|
TokenEndpoint: base + "/oauth/token",
|
||||||
|
RegistrationEndpoint: base + "/oauth/register",
|
||||||
|
ScopesSupported: []string{"mcp"},
|
||||||
|
ResponseTypesSupported: []string{"code"},
|
||||||
|
GrantTypesSupported: []string{"authorization_code", "client_credentials"},
|
||||||
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "none"},
|
||||||
|
CodeChallengeMethodsSupported: []string{"S256"},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// oauthRegisterHandler serves POST /oauth/register per RFC 7591
|
||||||
|
// (OAuth 2.0 Dynamic Client Registration).
|
||||||
|
func oauthRegisterHandler(dynClients *auth.DynamicClientStore, log *slog.Logger) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
w.Header().Set("Allow", http.MethodPost)
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req registerRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.RedirectURIs) == 0 {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"invalid_client_metadata","error_description":"redirect_uris is required"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dynClients.Register(req.ClientName, req.RedirectURIs)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oauth register: failed", slog.String("error", err.Error()))
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("oauth register: new client",
|
||||||
|
slog.String("client_id", client.ClientID),
|
||||||
|
slog.String("client_name", client.ClientName),
|
||||||
|
)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
_ = json.NewEncoder(w).Encode(registerResponse{
|
||||||
|
ClientID: client.ClientID,
|
||||||
|
ClientName: client.ClientName,
|
||||||
|
RedirectURIs: client.RedirectURIs,
|
||||||
|
GrantTypes: []string{"authorization_code"},
|
||||||
|
ResponseTypes: []string{"code"},
|
||||||
|
TokenEndpointAuthMethod: "none",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// oauthAuthorizeHandler serves GET and POST /oauth/authorize.
|
||||||
|
// GET shows an approval page; POST processes the user's approve/deny action.
|
||||||
|
func oauthAuthorizeHandler(dynClients *auth.DynamicClientStore, authCodes *auth.AuthCodeStore, log *slog.Logger) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
handleAuthorizeGET(w, r, dynClients)
|
||||||
|
case http.MethodPost:
|
||||||
|
handleAuthorizePOST(w, r, dynClients, authCodes, log)
|
||||||
|
default:
|
||||||
|
w.Header().Set("Allow", "GET, POST")
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleAuthorizeGET(w http.ResponseWriter, r *http.Request, dynClients *auth.DynamicClientStore) {
|
||||||
|
q := r.URL.Query()
|
||||||
|
clientID := q.Get("client_id")
|
||||||
|
redirectURI := q.Get("redirect_uri")
|
||||||
|
responseType := q.Get("response_type")
|
||||||
|
state := q.Get("state")
|
||||||
|
codeChallenge := q.Get("code_challenge")
|
||||||
|
codeChallengeMethod := q.Get("code_challenge_method")
|
||||||
|
scope := q.Get("scope")
|
||||||
|
|
||||||
|
// Validate client and redirect_uri before any redirect (prevents open redirect).
|
||||||
|
client, ok := dynClients.Lookup(clientID)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "unknown client_id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !client.HasRedirectURI(redirectURI) {
|
||||||
|
http.Error(w, "redirect_uri not registered for this client", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errors from here can safely redirect back to the client.
|
||||||
|
if responseType != "code" {
|
||||||
|
oauthRedirectError(w, r, redirectURI, "unsupported_response_type", state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if codeChallenge == "" || codeChallengeMethod != "S256" {
|
||||||
|
oauthRedirectError(w, r, redirectURI, "invalid_request", state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serveAuthorizePage(w, client.ClientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleAuthorizePOST(w http.ResponseWriter, r *http.Request, dynClients *auth.DynamicClientStore, authCodes *auth.AuthCodeStore, log *slog.Logger) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := r.FormValue("client_id")
|
||||||
|
redirectURI := r.FormValue("redirect_uri")
|
||||||
|
state := r.FormValue("state")
|
||||||
|
codeChallenge := r.FormValue("code_challenge")
|
||||||
|
codeChallengeMethod := r.FormValue("code_challenge_method")
|
||||||
|
scope := r.FormValue("scope")
|
||||||
|
action := r.FormValue("action")
|
||||||
|
|
||||||
|
client, ok := dynClients.Lookup(clientID)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "unknown client_id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !client.HasRedirectURI(redirectURI) {
|
||||||
|
http.Error(w, "redirect_uri not registered for this client", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if action == "deny" {
|
||||||
|
oauthRedirectError(w, r, redirectURI, "access_denied", state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code, err := authCodes.Issue(auth.AuthCode{
|
||||||
|
ClientID: clientID,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
Scope: scope,
|
||||||
|
CodeChallenge: codeChallenge,
|
||||||
|
CodeChallengeMethod: codeChallengeMethod,
|
||||||
|
KeyID: clientID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oauth authorize: failed to issue code", slog.String("error", err.Error()))
|
||||||
|
oauthRedirectError(w, r, redirectURI, "server_error", state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
target := redirectURI + "?code=" + url.QueryEscape(code)
|
||||||
|
if state != "" {
|
||||||
|
target += "&state=" + url.QueryEscape(state)
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, target, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// oauthTokenHandler serves POST /oauth/token.
|
||||||
|
// Supports grant_type=client_credentials and grant_type=authorization_code.
|
||||||
|
func oauthTokenHandler(oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, log *slog.Logger) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
w.Header().Set("Allow", http.MethodPost)
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
writeTokenError(w, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.FormValue("grant_type") {
|
||||||
|
case "client_credentials":
|
||||||
|
handleClientCredentials(w, r, oauthRegistry, tokenStore, log)
|
||||||
|
case "authorization_code":
|
||||||
|
handleAuthorizationCode(w, r, authCodes, tokenStore, log)
|
||||||
|
default:
|
||||||
|
writeTokenError(w, "unsupported_grant_type", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleClientCredentials(w http.ResponseWriter, r *http.Request, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, log *slog.Logger) {
|
||||||
|
clientID, clientSecret := extractOAuthBasicOrBody(r)
|
||||||
|
if clientID == "" || clientSecret == "" {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
|
||||||
|
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", r.RemoteAddr))
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
|
||||||
|
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
issueToken(w, keyID, tokenStore, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleAuthorizationCode(w http.ResponseWriter, r *http.Request, authCodes *auth.AuthCodeStore, tokenStore *auth.TokenStore, log *slog.Logger) {
|
||||||
|
code := r.FormValue("code")
|
||||||
|
redirectURI := r.FormValue("redirect_uri")
|
||||||
|
clientID := r.FormValue("client_id")
|
||||||
|
codeVerifier := r.FormValue("code_verifier")
|
||||||
|
|
||||||
|
if code == "" || redirectURI == "" || codeVerifier == "" {
|
||||||
|
writeTokenError(w, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, ok := authCodes.Consume(code)
|
||||||
|
if !ok {
|
||||||
|
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if entry.ClientID != clientID || entry.RedirectURI != redirectURI {
|
||||||
|
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !verifyPKCE(codeVerifier, entry.CodeChallenge, entry.CodeChallengeMethod) {
|
||||||
|
log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", r.RemoteAddr))
|
||||||
|
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
issueToken(w, entry.KeyID, tokenStore, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func issueToken(w http.ResponseWriter, keyID string, tokenStore *auth.TokenStore, log *slog.Logger) {
|
||||||
|
token, ttl, err := tokenStore.Issue(keyID)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oauth token: failed to issue token", slog.String("error", err.Error()))
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
w.Header().Set("Pragma", "no-cache")
|
||||||
|
_ = json.NewEncoder(w).Encode(tokenResponse{
|
||||||
|
AccessToken: token,
|
||||||
|
TokenType: "bearer",
|
||||||
|
ExpiresIn: int(ttl / time.Second),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helpers ---
|
||||||
|
|
||||||
|
func serveAuthorizePage(w http.ResponseWriter, clientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope string) {
|
||||||
|
e := html.EscapeString
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset=utf-8>
|
||||||
|
<title>Authorize — AMCS</title>
|
||||||
|
<style>
|
||||||
|
body{font-family:system-ui,sans-serif;max-width:480px;margin:80px auto;padding:0 1rem}
|
||||||
|
button{padding:.5rem 1.2rem;margin-right:.5rem;cursor:pointer;font-size:1rem}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h2>Authorize Access</h2>
|
||||||
|
<p><strong>%s</strong> is requesting access to this AMCS server.</p>
|
||||||
|
<form method=POST action=/oauth/authorize>
|
||||||
|
<input type=hidden name=client_id value="%s">
|
||||||
|
<input type=hidden name=redirect_uri value="%s">
|
||||||
|
<input type=hidden name=state value="%s">
|
||||||
|
<input type=hidden name=code_challenge value="%s">
|
||||||
|
<input type=hidden name=code_challenge_method value="%s">
|
||||||
|
<input type=hidden name=scope value="%s">
|
||||||
|
<button type=submit name=action value=approve>Approve</button>
|
||||||
|
<button type=submit name=action value=deny>Deny</button>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>`,
|
||||||
|
e(clientName), e(clientID), e(redirectURI), e(state),
|
||||||
|
e(codeChallenge), e(codeChallengeMethod), e(scope))
|
||||||
|
}
|
||||||
|
|
||||||
|
func oauthRedirectError(w http.ResponseWriter, r *http.Request, redirectURI, errCode, state string) {
|
||||||
|
target := redirectURI + "?error=" + url.QueryEscape(errCode)
|
||||||
|
if state != "" {
|
||||||
|
target += "&state=" + url.QueryEscape(state)
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, target, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyPKCE(verifier, challenge, method string) bool {
|
||||||
|
if method != "S256" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
got := base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
return subtle.ConstantTimeCompare([]byte(got), []byte(challenge)) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverBaseURL(r *http.Request) string {
|
||||||
|
scheme := "https"
|
||||||
|
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||||
|
scheme = strings.ToLower(proto)
|
||||||
|
} else if r.TLS == nil {
|
||||||
|
scheme = "http"
|
||||||
|
}
|
||||||
|
return scheme + "://" + r.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOAuthBasicOrBody(r *http.Request) (string, string) {
|
||||||
|
if id, secret, ok := r.BasicAuth(); ok {
|
||||||
|
return id, secret
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(r.FormValue("client_id")), strings.TrimSpace(r.FormValue("client_secret"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTokenError(w http.ResponseWriter, errCode string, status int) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(tokenErrorResponse{Error: errCode})
|
||||||
|
}
|
||||||
76
internal/auth/auth_code_store.go
Normal file
76
internal/auth/auth_code_store.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const authCodeTTL = 10 * time.Minute
|
||||||
|
|
||||||
|
// AuthCode holds a pending authorization code and its associated PKCE data.
|
||||||
|
type AuthCode struct {
|
||||||
|
ClientID string
|
||||||
|
RedirectURI string
|
||||||
|
Scope string
|
||||||
|
CodeChallenge string
|
||||||
|
CodeChallengeMethod string
|
||||||
|
KeyID string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthCodeStore issues single-use authorization codes for the OAuth 2.0
|
||||||
|
// Authorization Code flow.
|
||||||
|
type AuthCodeStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
codes map[string]AuthCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthCodeStore() *AuthCodeStore {
|
||||||
|
s := &AuthCodeStore{codes: make(map[string]AuthCode)}
|
||||||
|
go s.sweepLoop()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issue stores the entry and returns the raw authorization code.
|
||||||
|
func (s *AuthCodeStore) Issue(entry AuthCode) (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
raw := hex.EncodeToString(b)
|
||||||
|
entry.ExpiresAt = time.Now().Add(authCodeTTL)
|
||||||
|
s.mu.Lock()
|
||||||
|
s.codes[raw] = entry
|
||||||
|
s.mu.Unlock()
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume validates and removes the code, returning the associated entry.
|
||||||
|
func (s *AuthCodeStore) Consume(code string) (AuthCode, bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
entry, ok := s.codes[code]
|
||||||
|
if !ok || time.Now().After(entry.ExpiresAt) {
|
||||||
|
delete(s.codes, code)
|
||||||
|
return AuthCode{}, false
|
||||||
|
}
|
||||||
|
delete(s.codes, code)
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthCodeStore) sweepLoop() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.Lock()
|
||||||
|
for code, entry := range s.codes {
|
||||||
|
if now.After(entry.ExpiresAt) {
|
||||||
|
delete(s.codes, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
62
internal/auth/dynamic_client_store.go
Normal file
62
internal/auth/dynamic_client_store.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DynamicClient holds a dynamically registered OAuth client (RFC 7591).
|
||||||
|
type DynamicClient struct {
|
||||||
|
ClientID string
|
||||||
|
ClientName string
|
||||||
|
RedirectURIs []string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRedirectURI reports whether uri is registered for this client.
|
||||||
|
func (c *DynamicClient) HasRedirectURI(uri string) bool {
|
||||||
|
for _, u := range c.RedirectURIs {
|
||||||
|
if u == uri {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// DynamicClientStore holds dynamically registered OAuth clients in memory.
|
||||||
|
type DynamicClientStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
clients map[string]DynamicClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDynamicClientStore() *DynamicClientStore {
|
||||||
|
return &DynamicClientStore{clients: make(map[string]DynamicClient)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register creates a new client and returns it.
|
||||||
|
func (s *DynamicClientStore) Register(name string, redirectURIs []string) (DynamicClient, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return DynamicClient{}, err
|
||||||
|
}
|
||||||
|
client := DynamicClient{
|
||||||
|
ClientID: hex.EncodeToString(b),
|
||||||
|
ClientName: name,
|
||||||
|
RedirectURIs: append([]string(nil), redirectURIs...),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.clients[client.ClientID] = client
|
||||||
|
s.mu.Unlock()
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup returns the client for the given client_id.
|
||||||
|
func (s *DynamicClientStore) Lookup(clientID string) (DynamicClient, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
client, ok := s.clients[clientID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return client, ok
|
||||||
|
}
|
||||||
@@ -39,7 +39,7 @@ func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) {
|
|||||||
t.Fatalf("NewKeyring() error = %v", err)
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
keyID, ok := KeyIDFromContext(r.Context())
|
keyID, ok := KeyIDFromContext(r.Context())
|
||||||
if !ok || keyID != "client-a" {
|
if !ok || keyID != "client-a" {
|
||||||
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
||||||
@@ -63,7 +63,7 @@ func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) {
|
|||||||
t.Fatalf("NewKeyring() error = %v", err)
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
keyID, ok := KeyIDFromContext(r.Context())
|
keyID, ok := KeyIDFromContext(r.Context())
|
||||||
if !ok || keyID != "client-a" {
|
if !ok || keyID != "client-a" {
|
||||||
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
||||||
@@ -90,7 +90,7 @@ func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) {
|
|||||||
t.Fatalf("NewKeyring() error = %v", err)
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
keyID, ok := KeyIDFromContext(r.Context())
|
keyID, ok := KeyIDFromContext(r.Context())
|
||||||
if !ok || keyID != "client-a" {
|
if !ok || keyID != "client-a" {
|
||||||
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
||||||
@@ -119,7 +119,7 @@ func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) {
|
|||||||
HeaderName: "x-brain-key",
|
HeaderName: "x-brain-key",
|
||||||
QueryParam: "key",
|
QueryParam: "key",
|
||||||
AllowQueryParam: true,
|
AllowQueryParam: true,
|
||||||
}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) {
|
|||||||
t.Fatalf("NewKeyring() error = %v", err)
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Fatal("next handler should not be called")
|
t.Fatal("next handler should not be called")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -13,32 +14,77 @@ type contextKey string
|
|||||||
|
|
||||||
const keyIDContextKey contextKey = "auth.key_id"
|
const keyIDContextKey contextKey = "auth.key_id"
|
||||||
|
|
||||||
func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler {
|
func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthRegistry, tokenStore *TokenStore, log *slog.Logger) func(http.Handler) http.Handler {
|
||||||
headerName := cfg.HeaderName
|
headerName := cfg.HeaderName
|
||||||
if headerName == "" {
|
if headerName == "" {
|
||||||
headerName = "x-brain-key"
|
headerName = "x-brain-key"
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
token := extractToken(r, headerName)
|
// 1. Custom header → keyring only.
|
||||||
if token == "" && cfg.AllowQueryParam {
|
if keyring != nil {
|
||||||
token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam))
|
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
|
||||||
|
keyID, ok := keyring.Lookup(token)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||||
|
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if token == "" {
|
// 2. Bearer token → tokenStore (OAuth), then keyring (API key).
|
||||||
http.Error(w, "missing API key", http.StatusUnauthorized)
|
if bearer := extractBearer(r); bearer != "" {
|
||||||
|
if tokenStore != nil {
|
||||||
|
if keyID, ok := tokenStore.Lookup(bearer); ok {
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if keyring != nil {
|
||||||
|
if keyID, ok := keyring.Lookup(bearer); ok {
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr))
|
||||||
|
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keyID, ok := keyring.Lookup(token)
|
// 3. HTTP Basic → oauthRegistry (direct client credentials).
|
||||||
if !ok {
|
if clientID, clientSecret := extractOAuthClientCredentials(r); clientID != "" {
|
||||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
if oauthRegistry == nil {
|
||||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
http.Error(w, "authentication is not configured", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||||
|
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
// 4. Query param → keyring.
|
||||||
|
if keyring != nil && cfg.AllowQueryParam {
|
||||||
|
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
|
||||||
|
keyID, ok := keyring.Lookup(token)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||||
|
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(w, "authentication required", http.StatusUnauthorized)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -58,6 +104,30 @@ func extractToken(r *http.Request, headerName string) string {
|
|||||||
return strings.TrimSpace(credentials)
|
return strings.TrimSpace(credentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractBearer(r *http.Request) string {
|
||||||
|
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||||
|
scheme, credentials, ok := strings.Cut(authHeader, " ")
|
||||||
|
if !ok || !strings.EqualFold(scheme, "Bearer") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOAuthClientCredentials(r *http.Request) (string, string) {
|
||||||
|
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||||
|
scheme, credentials, ok := strings.Cut(authHeader, " ")
|
||||||
|
if ok && strings.EqualFold(scheme, "Basic") {
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(credentials))
|
||||||
|
if err == nil {
|
||||||
|
clientID, clientSecret, found := strings.Cut(string(decoded), ":")
|
||||||
|
if found {
|
||||||
|
return strings.TrimSpace(clientID), strings.TrimSpace(clientSecret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
||||||
value, ok := ctx.Value(keyIDContextKey).(string)
|
value, ok := ctx.Value(keyIDContextKey).(string)
|
||||||
return value, ok
|
return value, ok
|
||||||
|
|||||||
33
internal/auth/oauth_registry.go
Normal file
33
internal/auth/oauth_registry.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuthRegistry struct {
|
||||||
|
clients []config.OAuthClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthRegistry(clients []config.OAuthClient) (*OAuthRegistry, error) {
|
||||||
|
if len(clients) == 0 {
|
||||||
|
return nil, fmt.Errorf("oauth registry requires at least one client")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OAuthRegistry{clients: append([]config.OAuthClient(nil), clients...)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OAuthRegistry) Lookup(clientID string, clientSecret string) (string, bool) {
|
||||||
|
for _, client := range o.clients {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(client.ClientID), []byte(clientID)) == 1 &&
|
||||||
|
subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 1 {
|
||||||
|
if client.ID != "" {
|
||||||
|
return client.ID, true
|
||||||
|
}
|
||||||
|
return client.ClientID, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
92
internal/auth/oauth_registry_test.go
Normal file
92
internal/auth/oauth_registry_test.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewOAuthRegistryAndLookup(t *testing.T) {
|
||||||
|
_, err := NewOAuthRegistry(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("NewOAuthRegistry(nil) error = nil, want error")
|
||||||
|
}
|
||||||
|
|
||||||
|
registry, err := NewOAuthRegistry([]config.OAuthClient{{
|
||||||
|
ID: "oauth-client",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewOAuthRegistry() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := registry.Lookup("client-id", "client-secret"); !ok || got != "oauth-client" {
|
||||||
|
t.Fatalf("Lookup(client-id, client-secret) = (%q, %v), want (oauth-client, true)", got, ok)
|
||||||
|
}
|
||||||
|
if _, ok := registry.Lookup("client-id", "wrong"); ok {
|
||||||
|
t.Fatal("Lookup(client-id, wrong) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareAllowsOAuthBasicAuthAndSetsContext(t *testing.T) {
|
||||||
|
oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{
|
||||||
|
ID: "oauth-client",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewOAuthRegistry() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
keyID, ok := KeyIDFromContext(r.Context())
|
||||||
|
if !ok || keyID != "oauth-client" {
|
||||||
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (oauth-client, true)", keyID, ok)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:client-secret")))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func TestMiddlewareRejectsOAuthMissingOrInvalidCredentials(t *testing.T) {
|
||||||
|
oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{
|
||||||
|
ID: "oauth-client",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewOAuthRegistry() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Fatal("next handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("missing credentials status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:wrong")))
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("invalid credentials status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
74
internal/auth/token_store.go
Normal file
74
internal/auth/token_store.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTokenTTL = time.Hour
|
||||||
|
|
||||||
|
type tokenEntry struct {
|
||||||
|
keyID string
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenStore issues and validates short-lived opaque access tokens for OAuth
|
||||||
|
// client credentials flow.
|
||||||
|
type TokenStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
tokens map[string]tokenEntry
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTokenStore(ttl time.Duration) *TokenStore {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = defaultTokenTTL
|
||||||
|
}
|
||||||
|
s := &TokenStore{
|
||||||
|
tokens: make(map[string]tokenEntry),
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
go s.sweepLoop()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issue generates a new token for the given keyID and returns the token and its TTL.
|
||||||
|
func (s *TokenStore) Issue(keyID string) (string, time.Duration, error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
token := hex.EncodeToString(b)
|
||||||
|
s.mu.Lock()
|
||||||
|
s.tokens[token] = tokenEntry{keyID: keyID, expiresAt: time.Now().Add(s.ttl)}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return token, s.ttl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup validates a token and returns the associated keyID.
|
||||||
|
func (s *TokenStore) Lookup(token string) (string, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
entry, ok := s.tokens[token]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if !ok || time.Now().After(entry.expiresAt) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return entry.keyID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TokenStore) sweepLoop() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.Lock()
|
||||||
|
for token, entry := range s.tokens {
|
||||||
|
if now.After(entry.expiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -36,11 +36,11 @@ type MCPConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
Mode string `yaml:"mode"`
|
HeaderName string `yaml:"header_name"`
|
||||||
HeaderName string `yaml:"header_name"`
|
QueryParam string `yaml:"query_param"`
|
||||||
QueryParam string `yaml:"query_param"`
|
AllowQueryParam bool `yaml:"allow_query_param"`
|
||||||
AllowQueryParam bool `yaml:"allow_query_param"`
|
Keys []APIKey `yaml:"keys"`
|
||||||
Keys []APIKey `yaml:"keys"`
|
OAuth OAuthConfig `yaml:"oauth"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
@@ -49,6 +49,17 @@ type APIKey struct {
|
|||||||
Description string `yaml:"description"`
|
Description string `yaml:"description"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OAuthConfig struct {
|
||||||
|
Clients []OAuthClient `yaml:"clients"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthClient struct {
|
||||||
|
ID string `yaml:"id"`
|
||||||
|
ClientID string `yaml:"client_id"`
|
||||||
|
ClientSecret string `yaml:"client_secret"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
}
|
||||||
|
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
URL string `yaml:"url"`
|
URL string `yaml:"url"`
|
||||||
MaxConns int32 `yaml:"max_conns"`
|
MaxConns int32 `yaml:"max_conns"`
|
||||||
|
|||||||
@@ -32,8 +32,10 @@ func Load(explicitPath string) (*Config, string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ResolvePath(explicitPath string) string {
|
func ResolvePath(explicitPath string) string {
|
||||||
if strings.TrimSpace(explicitPath) != "" {
|
if path := strings.TrimSpace(explicitPath); path != "" {
|
||||||
return explicitPath
|
if path != ".yaml" && path != ".yml" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if envPath := strings.TrimSpace(os.Getenv("AMCS_CONFIG")); envPath != "" {
|
if envPath := strings.TrimSpace(os.Getenv("AMCS_CONFIG")); envPath != "" {
|
||||||
@@ -59,7 +61,6 @@ func defaultConfig() Config {
|
|||||||
Transport: "streamable_http",
|
Transport: "streamable_http",
|
||||||
},
|
},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
Mode: "api_keys",
|
|
||||||
HeaderName: "x-brain-key",
|
HeaderName: "x-brain-key",
|
||||||
QueryParam: "key",
|
QueryParam: "key",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -18,6 +18,18 @@ func TestResolvePathPrecedence(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) {
|
||||||
|
t.Setenv("AMCS_CONFIG", "/tmp/from-env.yaml")
|
||||||
|
|
||||||
|
if got := ResolvePath(".yaml"); got != "/tmp/from-env.yaml" {
|
||||||
|
t.Fatalf("ResolvePath(.yaml) = %q, want %q", got, "/tmp/from-env.yaml")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := ResolvePath(".yml"); got != "/tmp/from-env.yaml" {
|
||||||
|
t.Fatalf("ResolvePath(.yml) = %q, want %q", got, "/tmp/from-env.yaml")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
||||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||||
if err := os.WriteFile(configPath, []byte(`
|
if err := os.WriteFile(configPath, []byte(`
|
||||||
|
|||||||
@@ -10,10 +10,9 @@ func (c Config) Validate() error {
|
|||||||
return fmt.Errorf("invalid config: database.url is required")
|
return fmt.Errorf("invalid config: database.url is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.Auth.Keys) == 0 {
|
if len(c.Auth.Keys) == 0 && len(c.Auth.OAuth.Clients) == 0 {
|
||||||
return fmt.Errorf("invalid config: auth.keys must not be empty")
|
return fmt.Errorf("invalid config: at least one of auth.keys or auth.oauth.clients must be configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, key := range c.Auth.Keys {
|
for i, key := range c.Auth.Keys {
|
||||||
if strings.TrimSpace(key.ID) == "" {
|
if strings.TrimSpace(key.ID) == "" {
|
||||||
return fmt.Errorf("invalid config: auth.keys[%d].id is required", i)
|
return fmt.Errorf("invalid config: auth.keys[%d].id is required", i)
|
||||||
@@ -22,6 +21,14 @@ func (c Config) Validate() error {
|
|||||||
return fmt.Errorf("invalid config: auth.keys[%d].value is required", i)
|
return fmt.Errorf("invalid config: auth.keys[%d].value is required", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for i, client := range c.Auth.OAuth.Clients {
|
||||||
|
if strings.TrimSpace(client.ClientID) == "" {
|
||||||
|
return fmt.Errorf("invalid config: auth.oauth.clients[%d].client_id is required", i)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(client.ClientSecret) == "" {
|
||||||
|
return fmt.Errorf("invalid config: auth.oauth.clients[%d].client_secret is required", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(c.MCP.Path) == "" {
|
if strings.TrimSpace(c.MCP.Path) == "" {
|
||||||
return fmt.Errorf("invalid config: mcp.path is required")
|
return fmt.Errorf("invalid config: mcp.path is required")
|
||||||
|
|||||||
@@ -67,3 +67,47 @@ func TestValidateRejectsEmptyAuthKeyValue(t *testing.T) {
|
|||||||
t.Fatal("Validate() error = nil, want error for empty auth key value")
|
t.Fatal("Validate() error = nil, want error for empty auth key value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateAcceptsOAuthClients(t *testing.T) {
|
||||||
|
cfg := validConfig()
|
||||||
|
cfg.Auth = AuthConfig{
|
||||||
|
OAuth: OAuthConfig{
|
||||||
|
Clients: []OAuthClient{{
|
||||||
|
ID: "oauth-client",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAcceptsBothAuthMethods(t *testing.T) {
|
||||||
|
cfg := validConfig()
|
||||||
|
cfg.Auth = AuthConfig{
|
||||||
|
Keys: []APIKey{{ID: "key1", Value: "secret"}},
|
||||||
|
OAuth: OAuthConfig{
|
||||||
|
Clients: []OAuthClient{{
|
||||||
|
ID: "oauth-client",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRejectsEmptyAuth(t *testing.T) {
|
||||||
|
cfg := validConfig()
|
||||||
|
cfg.Auth = AuthConfig{}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err == nil {
|
||||||
|
t.Fatal("Validate() error = nil, want error when neither auth.keys nor auth.oauth.clients is configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
# AMCS Memory Instructions
|
# AMCS Memory Instructions
|
||||||
|
|
||||||
|
AMCS (Avalon Memory Crystal Server) is an MCP server for capturing and retrieving thoughts, memory, and project context. It is backed by Postgres with pgvector for semantic search.
|
||||||
|
|
||||||
You have access to an MCP memory server named AMCS.
|
You have access to an MCP memory server named AMCS.
|
||||||
|
|
||||||
Use AMCS as memory with two scopes:
|
Use AMCS as memory with two scopes:
|
||||||
@@ -19,8 +21,8 @@ Use AMCS as memory with two scopes:
|
|||||||
## Project Memory Rules
|
## Project Memory Rules
|
||||||
|
|
||||||
- Use project memory for code decisions, architecture, TODOs, debugging findings, and context specific to the current repo or workstream.
|
- Use project memory for code decisions, architecture, TODOs, debugging findings, and context specific to the current repo or workstream.
|
||||||
- Before substantial work, retrieve context with `get_project_context` or `recall_context`.
|
- Before substantial work, always retrieve context with `get_project_context` or `recall_context` so prior decisions inform your approach.
|
||||||
- Save durable project facts with `capture_thought`.
|
- Save durable project facts with `capture_thought` after completing meaningful work.
|
||||||
- Do not attach memory to the wrong project.
|
- Do not attach memory to the wrong project.
|
||||||
|
|
||||||
## Global Notebook Rules
|
## Global Notebook Rules
|
||||||
|
|||||||
Reference in New Issue
Block a user