From 6502b55797d45f0152e5278835d3f1676cf09264 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 7 Apr 2026 22:56:05 +0200 Subject: [PATCH] feat(security): implement OAuth2 authorization server with database support - Add OAuthServer for handling OAuth2 flows including authorization, token exchange, and client registration. - Introduce DatabaseAuthenticator for persisting clients and authorization codes. - Implement SQL procedures for client registration, code saving, and token introspection. - Support for external OAuth2 providers and PKCE (Proof Key for Code Exchange). --- pkg/resolvemcp/README.md | 132 ++++- pkg/resolvemcp/handler.go | 17 +- pkg/resolvemcp/oauth2.go | 264 ++++++++++ pkg/resolvemcp/oauth2_server.go | 51 ++ pkg/security/README.md | 156 ++++++ pkg/security/database_schema.sql | 170 ++++++ pkg/security/oauth_server.go | 859 +++++++++++++++++++++++++++++++ pkg/security/oauth_server_db.go | 202 ++++++++ pkg/security/sql_names.go | 32 ++ 9 files changed, 1874 insertions(+), 9 deletions(-) create mode 100644 pkg/resolvemcp/oauth2.go create mode 100644 pkg/resolvemcp/oauth2_server.go create mode 100644 pkg/security/oauth_server.go create mode 100644 pkg/security/oauth_server_db.go diff --git a/pkg/resolvemcp/README.md b/pkg/resolvemcp/README.md index 68399c8..660599f 100644 --- a/pkg/resolvemcp/README.md +++ b/pkg/resolvemcp/README.md @@ -142,9 +142,137 @@ e.Any("/mcp", echo.WrapHandler(h)) // Echo --- -### Authentication +## OAuth2 Authentication -Add middleware before the MCP routes. The handler itself has no auth layer. +`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically. + +It can operate as: +- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()` +- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.) +- **Both simultaneously** + +### Standard endpoints served + +| Path | Spec | Purpose | +|---|---|---| +| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery | +| `POST /oauth/register` | RFC 7591 | Dynamic client registration | +| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) | +| `POST /oauth/authorize` | — | Login form submission | +| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange | +| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation | +| `GET /oauth/provider/callback` | Internal | External provider redirect target | + +MCP clients send `Authorization: Bearer ` on all subsequent requests. + +--- + +### Mode 1 — Direct login (server as identity provider) + +```go +import "github.com/bitechdev/ResolveSpec/pkg/security" + +db, _ := sql.Open("postgres", dsn) +auth := security.NewDatabaseAuthenticator(db) + +handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{ + BaseURL: "https://api.example.com", + BasePath: "/mcp", +}) + +// Enable the OAuth2 server — auth enables the login form +handler.EnableOAuthServer(security.OAuthServerConfig{ + Issuer: "https://api.example.com", +}, auth) + +provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec) +securityList, _ := security.NewSecurityList(provider) +security.RegisterSecurityHooks(handler, securityList) + +http.ListenAndServe(":8080", handler.HTTPHandler(securityList)) +``` + +MCP client flow: +1. Discovers server at `/.well-known/oauth-authorization-server` +2. Registers itself at `/oauth/register` +3. Redirects user to `/oauth/authorize` → login form appears +4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token +5. Uses token on all MCP tool calls + +--- + +### Mode 2 — External provider (Google, GitHub, etc.) + +The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server. + +```go +auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{ + ClientID: os.Getenv("GOOGLE_CLIENT_ID"), + ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"), + RedirectURL: "https://api.example.com/oauth/provider/callback", + Scopes: []string{"openid", "profile", "email"}, + AuthURL: "https://accounts.google.com/o/oauth2/auth", + TokenURL: "https://oauth2.googleapis.com/token", + UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + ProviderName: "google", +}) + +// nil = no password login; Google handles auth +handler.EnableOAuthServer(security.OAuthServerConfig{ + Issuer: "https://api.example.com", +}, nil) +handler.RegisterOAuth2Provider(auth, "google") +``` + +--- + +### Mode 3 — Both (login form + external providers) + +```go +handler.EnableOAuthServer(security.OAuthServerConfig{ + Issuer: "https://api.example.com", + LoginTitle: "My App Login", +}, auth) // auth enables the username/password form + +handler.RegisterOAuth2Provider(googleAuth, "google") +handler.RegisterOAuth2Provider(githubAuth, "github") +``` + +When external providers are registered they take priority; the login form is used as fallback when no providers are configured. + +--- + +### Using `security.OAuthServer` standalone + +The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`: + +```go +oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{ + Issuer: "https://api.example.com", +}, auth) +oauthSrv.RegisterExternalProvider(googleAuth, "google") + +mux := http.NewServeMux() +mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes +mux.Handle("/mcp/", myMCPHandler) +http.ListenAndServe(":8080", mux) +``` + +--- + +### Cookie-based flow (legacy) + +For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login: + +```go +resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{ + ProviderName: "google", + LoginPath: "/auth/google/login", + CallbackPath: "/auth/google/callback", + AfterLoginRedirect: "/", +}) +resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList) +``` --- diff --git a/pkg/resolvemcp/handler.go b/pkg/resolvemcp/handler.go index 4ed7c9c..df29499 100644 --- a/pkg/resolvemcp/handler.go +++ b/pkg/resolvemcp/handler.go @@ -16,17 +16,20 @@ import ( "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/reflection" + "github.com/bitechdev/ResolveSpec/pkg/security" ) // Handler exposes registered database models as MCP tools and resources. type Handler struct { - db common.Database - registry common.ModelRegistry - hooks *HookRegistry - mcpServer *server.MCPServer - config Config - name string - version string + db common.Database + registry common.ModelRegistry + hooks *HookRegistry + mcpServer *server.MCPServer + config Config + name string + version string + oauth2Regs []oauth2Registration + oauthSrv *security.OAuthServer } // NewHandler creates a Handler with the given database, model registry, and config. diff --git a/pkg/resolvemcp/oauth2.go b/pkg/resolvemcp/oauth2.go new file mode 100644 index 0000000..5ffa07e --- /dev/null +++ b/pkg/resolvemcp/oauth2.go @@ -0,0 +1,264 @@ +package resolvemcp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// -------------------------------------------------------------------------- +// OAuth2 registration on the Handler +// -------------------------------------------------------------------------- + +// oauth2Registration stores a configured auth provider and its route config. +type oauth2Registration struct { + auth *security.DatabaseAuthenticator + cfg OAuth2RouteConfig +} + +// RegisterOAuth2 attaches an OAuth2 provider to the Handler. +// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux. +// Call this once per provider before serving requests. +// +// Example: +// +// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) +// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{ +// ProviderName: "google", +// LoginPath: "/auth/google/login", +// CallbackPath: "/auth/google/callback", +// AfterLoginRedirect: "/", +// }) +func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) { + h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg}) +} + +// HTTPHandler returns a single http.Handler that serves: +// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called) +// - OAuth2 login + callback routes for every registered provider (legacy cookie flow) +// - The MCP SSE transport wrapped with required authentication middleware +// +// Example: +// +// auth := security.NewGoogleAuthenticator(...) +// handler.RegisterOAuth2(auth, cfg) +// handler.EnableOAuthServer(resolvemcp.OAuthServerConfig{Issuer: "https://api.example.com"}) +// security.RegisterSecurityHooks(handler, securityList) +// http.ListenAndServe(":8080", handler.HTTPHandler(securityList)) +func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler { + mux := http.NewServeMux() + if h.oauthSrv != nil { + h.mountOAuthServerRoutes(mux) + } + h.mountOAuth2Routes(mux) + + mcpHandler := h.AuthedSSEServer(securityList) + basePath := h.config.BasePath + if basePath == "" { + basePath = "/mcp" + } + mux.Handle(basePath+"/sse", mcpHandler) + mux.Handle(basePath+"/message", mcpHandler) + mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler)) + + return mux +} + +// StreamableHTTPMux returns a single http.Handler that serves: +// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called) +// - OAuth2 login + callback routes for every registered provider (legacy cookie flow) +// - The MCP streamable HTTP transport wrapped with required authentication middleware +// +// Example: +// +// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList)) +func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler { + mux := http.NewServeMux() + if h.oauthSrv != nil { + h.mountOAuthServerRoutes(mux) + } + h.mountOAuth2Routes(mux) + + mcpHandler := h.AuthedStreamableHTTPServer(securityList) + basePath := h.config.BasePath + if basePath == "" { + basePath = "/mcp" + } + mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler)) + mux.Handle(basePath, mcpHandler) + + return mux +} + +// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux. +func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) { + for _, reg := range h.oauth2Regs { + var cookieOpts []security.SessionCookieOptions + if reg.cfg.CookieOptions != nil { + cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions) + } + mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName)) + mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...)) + } +} + +// -------------------------------------------------------------------------- +// Auth-wrapped transports +// -------------------------------------------------------------------------- + +// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security. +// The middleware reads the session cookie / Authorization header and populates the user +// context into the request context, making it available to BeforeHandle security hooks. +// Unauthenticated requests receive 401 before reaching any MCP tool. +func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler { + return security.NewAuthMiddleware(securityList)(h.SSEServer()) +} + +// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware. +// Unauthenticated requests continue as guest rather than returning 401. +// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules +// to allow mixed public/private access. +func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler { + return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer()) +} + +// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware. +func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler { + return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer()) +} + +// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware. +func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler { + return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer()) +} + +// -------------------------------------------------------------------------- +// OAuth2 route config and standalone handlers +// -------------------------------------------------------------------------- + +// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider. +type OAuth2RouteConfig struct { + // ProviderName is the OAuth2 provider name as registered with WithOAuth2() + // (e.g. "google", "github", "microsoft"). + ProviderName string + + // LoginPath is the HTTP path that redirects the browser to the OAuth2 provider + // (e.g. "/auth/google/login"). + LoginPath string + + // CallbackPath is the HTTP path that the OAuth2 provider redirects back to + // (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config. + CallbackPath string + + // AfterLoginRedirect is the URL to redirect the browser to after a successful + // login. When empty the LoginResponse JSON is written directly to the response. + AfterLoginRedirect string + + // CookieOptions customises the session cookie written on successful login. + // Defaults to HttpOnly, Secure, SameSite=Lax when nil. + CookieOptions *security.SessionCookieOptions +} + +// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to +// the OAuth2 provider's authorization URL. +// +// Register it on any router: +// +// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google")) +func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + state, err := auth.OAuth2GenerateState() + if err != nil { + http.Error(w, "failed to generate state", http.StatusInternalServerError) + return + } + authURL, err := auth.OAuth2GetAuthURL(providerName, state) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + http.Redirect(w, r, authURL, http.StatusTemporaryRedirect) + } +} + +// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider +// callback: exchanges the authorization code for a session token, writes the session +// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON. +// +// Register it on any router: +// +// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard")) +func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + if code == "" { + http.Error(w, "missing code parameter", http.StatusBadRequest) + return + } + + loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + security.SetSessionCookie(w, loginResp, cookieOpts...) + + if afterLoginRedirect != "" { + http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(loginResp) //nolint:errcheck + } +} + +// -------------------------------------------------------------------------- +// Gorilla Mux convenience helpers +// -------------------------------------------------------------------------- + +// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router. +// +// Example: +// +// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{ +// ProviderName: "google", LoginPath: "/auth/google/login", +// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/", +// }) +func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) { + var cookieOpts []security.SessionCookieOptions + if cfg.CookieOptions != nil { + cookieOpts = append(cookieOpts, *cfg.CookieOptions) + } + + muxRouter.Handle(cfg.LoginPath, + OAuth2LoginHandler(auth, cfg.ProviderName), + ).Methods(http.MethodGet) + + muxRouter.Handle(cfg.CallbackPath, + OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...), + ).Methods(http.MethodGet) +} + +// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router +// with required authentication middleware applied. +func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) { + basePath := handler.config.BasePath + h := handler.AuthedSSEServer(securityList) + + muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions) + muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions) + muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h)) +} + +// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a +// Gorilla Mux router with required authentication middleware applied. +func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) { + basePath := handler.config.BasePath + h := handler.AuthedStreamableHTTPServer(securityList) + muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h)) +} diff --git a/pkg/resolvemcp/oauth2_server.go b/pkg/resolvemcp/oauth2_server.go new file mode 100644 index 0000000..30d0073 --- /dev/null +++ b/pkg/resolvemcp/oauth2_server.go @@ -0,0 +1,51 @@ +package resolvemcp + +import ( + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler. +// +// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as +// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use +// only external providers registered via RegisterOAuth2Provider. +// +// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant +// endpoints required by MCP clients alongside the MCP transport: +// +// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery +// POST /oauth/register RFC 7591 — dynamic client registration +// GET /oauth/authorize OAuth 2.1 + PKCE — start login +// POST /oauth/authorize Login form submission (password flow) +// POST /oauth/token Bearer token exchange + refresh +// GET /oauth/provider/callback External provider redirect target +func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) { + h.oauthSrv = security.NewOAuthServer(cfg, auth) + // Wire any external providers already registered via RegisterOAuth2 + for _, reg := range h.oauth2Regs { + h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName) + } +} + +// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server. +// EnableOAuthServer must be called before this. The auth must have been configured with +// WithOAuth2(providerName, ...) for the given provider name. +func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) { + if h.oauthSrv != nil { + h.oauthSrv.RegisterExternalProvider(auth, providerName) + } +} + +// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux. +func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) { + oauthHandler := h.oauthSrv.HTTPHandler() + // Delegate all /oauth/ and /.well-known/ paths to the OAuth server + mux.Handle("/.well-known/", oauthHandler) + mux.Handle("/oauth/", oauthHandler) + if h.oauthSrv != nil { + // Also mount the external provider callback path if it differs from /oauth/ + mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler) + } +} diff --git a/pkg/security/README.md b/pkg/security/README.md index d39cd84..4891fef 100644 --- a/pkg/security/README.md +++ b/pkg/security/README.md @@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic - ✅ **Testable** - Easy to mock and test - ✅ **Extensible** - Implement custom providers for your needs - ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability +- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation ## Stored Procedure Architecture @@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic | `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator | | `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider | | `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider | +| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator | +| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator | +| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator | +| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator | +| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator | +| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator | See `database_schema.sql` for complete stored procedure definitions and examples. @@ -897,6 +904,155 @@ securityList := security.NewSecurityList(provider) restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec ``` +## OAuth2 Authorization Server + +`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`. + +### Endpoints + +| Method | Path | RFC | +|--------|------|-----| +| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata | +| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration | +| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection | +| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission | +| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh | +| `POST` | `/oauth/revoke` | RFC 7009 — token revocation | +| `POST` | `/oauth/introspect` | RFC 7662 — token introspection | +| `GET` | `{ProviderCallbackPath}` | External provider redirect target | + +### Config + +```go +cfg := security.OAuthServerConfig{ + Issuer: "https://example.com", // Required — token issuer URL + ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target + LoginTitle: "My App Login", // HTML login page title + PersistClients: true, // Store clients in DB (multi-instance safe) + PersistCodes: true, // Store codes in DB (multi-instance safe) + DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested + AccessTokenTTL: time.Hour, + AuthCodeTTL: 5 * time.Minute, +} +``` + +| Field | Default | Notes | +|-------|---------|-------| +| `Issuer` | — | Required | +| `ProviderCallbackPath` | `/oauth/provider/callback` | | +| `LoginTitle` | `"Login"` | | +| `PersistClients` | `false` | Set `true` for multi-instance | +| `PersistCodes` | `false` | Set `true` for multi-instance | +| `DefaultScopes` | `nil` | | +| `AccessTokenTTL` | `1h` | | +| `AuthCodeTTL` | `5m` | | + +### Operating Modes + +**Mode 1 — Direct login (username/password form)** + +Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login. + +```go +auth := security.NewDatabaseAuthenticator(db) +srv := security.NewOAuthServer(cfg, auth) +``` + +**Mode 2 — External provider federation** + +Pass `nil` as auth and register external providers. The authorize page shows a provider selection UI. + +```go +srv := security.NewOAuthServer(cfg, nil) +srv.RegisterExternalProvider(googleAuth, "google") +srv.RegisterExternalProvider(githubAuth, "github") +``` + +**Mode 3 — Both** + +Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons. + +```go +srv := security.NewOAuthServer(cfg, auth) +srv.RegisterExternalProvider(googleAuth, "google") +``` + +### Standalone Usage + +```go +mux := http.NewServeMux() +mux.Handle("/.well-known/", srv.HTTPHandler()) +mux.Handle("/oauth/", srv.HTTPHandler()) +mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler()) + +http.ListenAndServe(":8080", mux) +``` + +### DB Persistence + +When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments. + +Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`. + +#### New DB Types + +```go +type OAuthServerClient struct { + ClientID string `json:"client_id"` + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name,omitempty"` + GrantTypes []string `json:"grant_types"` + AllowedScopes []string `json:"allowed_scopes,omitempty"` +} + +type OAuthCode struct { + Code string `json:"code"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + ClientState string `json:"client_state,omitempty"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + SessionToken string `json:"session_token"` + Scopes []string `json:"scopes,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + +type OAuthTokenInfo struct { + Active bool `json:"active"` + Sub string `json:"sub,omitempty"` + Username string `json:"username,omitempty"` + Email string `json:"email,omitempty"` + Roles []string `json:"roles,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` +} +``` + +#### DatabaseAuthenticator OAuth Methods + +```go +auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client +auth.OAuthGetClient(ctx, clientID) // retrieve client +auth.OAuthSaveCode(ctx, code) // persist authorization code +auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read) +auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo +auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session +``` + +#### SQLNames Fields + +```go +type SQLNames struct { + // ... existing fields ... + OAuthRegisterClient string // default: "resolvespec_oauth_register_client" + OAuthGetClient string // default: "resolvespec_oauth_get_client" + OAuthSaveCode string // default: "resolvespec_oauth_save_code" + OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code" + OAuthIntrospect string // default: "resolvespec_oauth_introspect" + OAuthRevoke string // default: "resolvespec_oauth_revoke" +} +``` + The main changes: 1. Security package no longer knows about specific spec types 2. Each spec registers its own security hooks diff --git a/pkg/security/database_schema.sql b/pkg/security/database_schema.sql index 12f31f1..bbddb9b 100644 --- a/pkg/security/database_schema.sql +++ b/pkg/security/database_schema.sql @@ -1397,3 +1397,173 @@ $$ LANGUAGE plpgsql; -- Get credentials by username -- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin'); + +-- ============================================ +-- OAuth2 Server Tables (OAuthServer persistence) +-- ============================================ + +-- oauth_clients: persistent RFC 7591 registered clients +CREATE TABLE IF NOT EXISTS oauth_clients ( + id SERIAL PRIMARY KEY, + client_id VARCHAR(255) NOT NULL UNIQUE, + redirect_uris TEXT[] NOT NULL, + client_name VARCHAR(255), + grant_types TEXT[] DEFAULT ARRAY['authorization_code'], + allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'], + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- oauth_codes: short-lived authorization codes (for multi-instance deployments) +CREATE TABLE IF NOT EXISTS oauth_codes ( + id SERIAL PRIMARY KEY, + code VARCHAR(255) NOT NULL UNIQUE, + client_id VARCHAR(255) NOT NULL REFERENCES oauth_clients(client_id) ON DELETE CASCADE, + redirect_uri TEXT NOT NULL, + client_state TEXT, + code_challenge VARCHAR(255) NOT NULL, + code_challenge_method VARCHAR(10) DEFAULT 'S256', + session_token TEXT NOT NULL, + scopes TEXT[], + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code); +CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at); + +-- ============================================ +-- OAuth2 Server Stored Procedures +-- ============================================ + +CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb) +RETURNS TABLE(p_success bool, p_error text, p_data jsonb) +LANGUAGE plpgsql AS $$ +DECLARE + v_client_id text; + v_row jsonb; +BEGIN + v_client_id := p_data->>'client_id'; + + INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes) + VALUES ( + v_client_id, + ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')), + p_data->>'client_name', + COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']), + COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email']) + ) + RETURNING to_jsonb(oauth_clients.*) INTO v_row; + + RETURN QUERY SELECT true, null::text, v_row; +EXCEPTION WHEN OTHERS THEN + RETURN QUERY SELECT false, SQLERRM, null::jsonb; +END; +$$; + +CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text) +RETURNS TABLE(p_success bool, p_error text, p_data jsonb) +LANGUAGE plpgsql AS $$ +DECLARE + v_row jsonb; +BEGIN + SELECT to_jsonb(oauth_clients.*) + INTO v_row + FROM oauth_clients + WHERE client_id = p_client_id AND is_active = true; + + IF v_row IS NULL THEN + RETURN QUERY SELECT false, 'client not found'::text, null::jsonb; + ELSE + RETURN QUERY SELECT true, null::text, v_row; + END IF; +END; +$$; + +CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb) +RETURNS TABLE(p_success bool, p_error text) +LANGUAGE plpgsql AS $$ +BEGIN + INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, scopes, expires_at) + VALUES ( + p_data->>'code', + p_data->>'client_id', + p_data->>'redirect_uri', + p_data->>'client_state', + p_data->>'code_challenge', + COALESCE(p_data->>'code_challenge_method', 'S256'), + p_data->>'session_token', + ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')), + (p_data->>'expires_at')::timestamp + ); + + RETURN QUERY SELECT true, null::text; +EXCEPTION WHEN OTHERS THEN + RETURN QUERY SELECT false, SQLERRM; +END; +$$; + +CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text) +RETURNS TABLE(p_success bool, p_error text, p_data jsonb) +LANGUAGE plpgsql AS $$ +DECLARE + v_row jsonb; +BEGIN + DELETE FROM oauth_codes + WHERE code = p_code AND expires_at > now() + RETURNING jsonb_build_object( + 'client_id', client_id, + 'redirect_uri', redirect_uri, + 'client_state', client_state, + 'code_challenge', code_challenge, + 'code_challenge_method', code_challenge_method, + 'session_token', session_token, + 'scopes', to_jsonb(scopes) + ) INTO v_row; + + IF v_row IS NULL THEN + RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb; + ELSE + RETURN QUERY SELECT true, null::text, v_row; + END IF; +END; +$$; + +CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text) +RETURNS TABLE(p_success bool, p_error text, p_data jsonb) +LANGUAGE plpgsql AS $$ +DECLARE + v_row jsonb; +BEGIN + SELECT jsonb_build_object( + 'active', true, + 'sub', u.id::text, + 'username', u.username, + 'email', u.email, + 'user_level', u.user_level, + 'roles', to_jsonb(string_to_array(COALESCE(u.roles, ''), ',')), + 'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint, + 'iat', EXTRACT(EPOCH FROM s.created_at)::bigint + ) + INTO v_row + FROM user_sessions s + JOIN users u ON u.id = s.user_id + WHERE s.session_token = p_token + AND s.expires_at > now() + AND u.is_active = true; + + IF v_row IS NULL THEN + RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb; + ELSE + RETURN QUERY SELECT true, null::text, v_row; + END IF; +END; +$$; + +CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text) +RETURNS TABLE(p_success bool, p_error text) +LANGUAGE plpgsql AS $$ +BEGIN + DELETE FROM user_sessions WHERE session_token = p_token; + RETURN QUERY SELECT true, null::text; +END; +$$; diff --git a/pkg/security/oauth_server.go b/pkg/security/oauth_server.go new file mode 100644 index 0000000..7fff22f --- /dev/null +++ b/pkg/security/oauth_server.go @@ -0,0 +1,859 @@ +package security + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// OAuthServerConfig configures the MCP-standard OAuth2 authorization server. +type OAuthServerConfig struct { + // Issuer is the public base URL of this server (e.g. "https://api.example.com"). + // Used in /.well-known/oauth-authorization-server and to build endpoint URLs. + Issuer string + + // ProviderCallbackPath is the path on this server that external OAuth2 providers + // redirect back to. Defaults to "/oauth/provider/callback". + ProviderCallbackPath string + + // LoginTitle is shown on the built-in login form when the server acts as its own + // identity provider. Defaults to "MCP Login". + LoginTitle string + + // PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided. + // Clients registered during a session survive server restarts. + PersistClients bool + + // PersistCodes stores authorization codes in the database. + // Useful for multi-instance deployments. Defaults to in-memory. + PersistCodes bool + + // DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"]. + DefaultScopes []string + + // AccessTokenTTL is the issued token lifetime. Defaults to 24h. + AccessTokenTTL time.Duration + + // AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes. + AuthCodeTTL time.Duration +} + +// oauthClient is a dynamically registered OAuth2 client (RFC 7591). +type oauthClient struct { + ClientID string `json:"client_id"` + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name,omitempty"` + GrantTypes []string `json:"grant_types"` + AllowedScopes []string `json:"allowed_scopes,omitempty"` +} + +// pendingAuth tracks an in-progress authorization code exchange. +type pendingAuth struct { + ClientID string + RedirectURI string + ClientState string + CodeChallenge string + CodeChallengeMethod string + ProviderName string // empty = password login + ExpiresAt time.Time + SessionToken string // set after authentication completes + Scopes []string // requested scopes +} + +// externalProvider pairs a DatabaseAuthenticator with its provider name. +type externalProvider struct { + auth *DatabaseAuthenticator + providerName string +} + +// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE). +// +// It can act as both: +// - A direct identity provider using DatabaseAuthenticator username/password login +// - A federation layer that delegates authentication to external OAuth2 providers +// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider +// +// The server exposes these RFC-compliant endpoints: +// +// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery +// POST /oauth/register RFC 7591 — dynamic client registration +// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization +// POST /oauth/authorize Direct login form submission +// POST /oauth/token Token exchange and refresh +// POST /oauth/revoke RFC 7009 — token revocation +// POST /oauth/introspect RFC 7662 — token introspection +// GET {ProviderCallbackPath} Internal — external provider callback +type OAuthServer struct { + cfg OAuthServerConfig + auth *DatabaseAuthenticator // nil = only external providers + providers []externalProvider + + mu sync.RWMutex + clients map[string]*oauthClient + pending map[string]*pendingAuth // provider_state → pending (external flow) + codes map[string]*pendingAuth // auth_code → pending (post-auth) +} + +// NewOAuthServer creates a new MCP OAuth2 authorization server. +// +// Pass a DatabaseAuthenticator to enable direct username/password login (the server +// acts as its own identity provider). Pass nil to use only external providers. +// External providers are added separately via RegisterExternalProvider. +func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer { + if cfg.ProviderCallbackPath == "" { + cfg.ProviderCallbackPath = "/oauth/provider/callback" + } + if cfg.LoginTitle == "" { + cfg.LoginTitle = "Sign in" + } + if len(cfg.DefaultScopes) == 0 { + cfg.DefaultScopes = []string{"openid", "profile", "email"} + } + if cfg.AccessTokenTTL == 0 { + cfg.AccessTokenTTL = 24 * time.Hour + } + if cfg.AuthCodeTTL == 0 { + cfg.AuthCodeTTL = 2 * time.Minute + } + s := &OAuthServer{ + cfg: cfg, + auth: auth, + clients: make(map[string]*oauthClient), + pending: make(map[string]*pendingAuth), + codes: make(map[string]*pendingAuth), + } + go s.cleanupExpired() + return s +} + +// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.) +// that handles user authentication via redirect. The DatabaseAuthenticator must have been +// configured with WithOAuth2(providerName, ...) before calling this. +// Multiple providers can be registered; the first is used as the default. +func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) { + s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName}) +} + +// ProviderCallbackPath returns the configured path for external provider callbacks. +func (s *OAuthServer) ProviderCallbackPath() string { + return s.cfg.ProviderCallbackPath +} + +// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints. +// Mount it at the root of your HTTP server alongside the MCP transport. +// +// mux := http.NewServeMux() +// mux.Handle("/", oauthServer.HTTPHandler()) +// mux.Handle("/mcp/", mcpTransport) +func (s *OAuthServer) HTTPHandler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler) + mux.HandleFunc("/oauth/register", s.registerHandler) + mux.HandleFunc("/oauth/authorize", s.authorizeHandler) + mux.HandleFunc("/oauth/token", s.tokenHandler) + mux.HandleFunc("/oauth/revoke", s.revokeHandler) + mux.HandleFunc("/oauth/introspect", s.introspectHandler) + mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler) + return mux +} + +// cleanupExpired removes stale pending auths and codes every 5 minutes. +func (s *OAuthServer) cleanupExpired() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + now := time.Now() + s.mu.Lock() + for k, p := range s.pending { + if now.After(p.ExpiresAt) { + delete(s.pending, k) + } + } + for k, p := range s.codes { + if now.After(p.ExpiresAt) { + delete(s.codes, k) + } + } + s.mu.Unlock() + } +} + +// -------------------------------------------------------------------------- +// RFC 8414 — Server metadata +// -------------------------------------------------------------------------- + +func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) { + issuer := s.cfg.Issuer + meta := map[string]interface{}{ + "issuer": issuer, + "authorization_endpoint": issuer + "/oauth/authorize", + "token_endpoint": issuer + "/oauth/token", + "registration_endpoint": issuer + "/oauth/register", + "revocation_endpoint": issuer + "/oauth/revoke", + "introspection_endpoint": issuer + "/oauth/introspect", + "scopes_supported": s.cfg.DefaultScopes, + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code", "refresh_token"}, + "code_challenge_methods_supported": []string{"S256"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(meta) //nolint:errcheck +} + +// -------------------------------------------------------------------------- +// RFC 7591 — Dynamic client registration +// -------------------------------------------------------------------------- + +func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var req struct { + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name"` + GrantTypes []string `json:"grant_types"` + AllowedScopes []string `json:"allowed_scopes"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest) + return + } + if len(req.RedirectURIs) == 0 { + writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest) + return + } + grantTypes := req.GrantTypes + if len(grantTypes) == 0 { + grantTypes = []string{"authorization_code"} + } + allowedScopes := req.AllowedScopes + if len(allowedScopes) == 0 { + allowedScopes = s.cfg.DefaultScopes + } + clientID, err := randomOAuthToken() + if err != nil { + http.Error(w, "server error", http.StatusInternalServerError) + return + } + client := &oauthClient{ + ClientID: clientID, + RedirectURIs: req.RedirectURIs, + ClientName: req.ClientName, + GrantTypes: grantTypes, + AllowedScopes: allowedScopes, + } + + if s.cfg.PersistClients && s.auth != nil { + dbClient := &OAuthServerClient{ + ClientID: client.ClientID, + RedirectURIs: client.RedirectURIs, + ClientName: client.ClientName, + GrantTypes: client.GrantTypes, + AllowedScopes: client.AllowedScopes, + } + if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil { + http.Error(w, "server error", http.StatusInternalServerError) + return + } + } + + s.mu.Lock() + s.clients[clientID] = client + s.mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(client) //nolint:errcheck +} + +// -------------------------------------------------------------------------- +// Authorization endpoint — GET + POST /oauth/authorize +// -------------------------------------------------------------------------- + +func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + s.authorizeGet(w, r) + case http.MethodPost: + s.authorizePost(w, r) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// authorizeGet validates the request and either: +// - Redirects to an external provider (if providers are registered) +// - Renders a login form (if the server is its own identity provider) +func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + clientID := q.Get("client_id") + redirectURI := q.Get("redirect_uri") + clientState := q.Get("state") + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + providerName := q.Get("provider") + scopeStr := q.Get("scope") + var scopes []string + if scopeStr != "" { + scopes = strings.Fields(scopeStr) + } + + if q.Get("response_type") != "code" { + writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest) + return + } + if codeChallenge == "" { + writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest) + return + } + if codeChallengeMethod != "" && codeChallengeMethod != "S256" { + writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest) + return + } + client, ok := s.lookupOrFetchClient(r.Context(), clientID) + if !ok { + writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest) + return + } + if !oauthSliceContains(client.RedirectURIs, redirectURI) { + writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest) + return + } + + // External provider path + if len(s.providers) > 0 { + s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes) + return + } + + // Direct login form path (server is its own identity provider) + if s.auth == nil { + http.Error(w, "no authentication provider configured", http.StatusInternalServerError) + return + } + s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "") +} + +// authorizePost handles login form submission for the direct login flow. +func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + clientID := r.FormValue("client_id") + redirectURI := r.FormValue("redirect_uri") + clientState := r.FormValue("client_state") + codeChallenge := r.FormValue("code_challenge") + codeChallengeMethod := r.FormValue("code_challenge_method") + username := r.FormValue("username") + password := r.FormValue("password") + scopeStr := r.FormValue("scope") + var scopes []string + if scopeStr != "" { + scopes = strings.Fields(scopeStr) + } + + client, ok := s.lookupOrFetchClient(r.Context(), clientID) + if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) { + http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest) + return + } + if s.auth == nil { + http.Error(w, "no authentication provider configured", http.StatusInternalServerError) + return + } + + loginResp, err := s.auth.Login(r.Context(), LoginRequest{ + Username: username, + Password: password, + }) + if err != nil { + s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password") + return + } + + s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes) +} + +// redirectToExternalProvider stores the pending auth and redirects to the configured provider. +func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) { + var provider *externalProvider + if providerName != "" { + for i := range s.providers { + if s.providers[i].providerName == providerName { + provider = &s.providers[i] + break + } + } + if provider == nil { + http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest) + return + } + } else { + provider = &s.providers[0] + } + + providerState, err := randomOAuthToken() + if err != nil { + http.Error(w, "server error", http.StatusInternalServerError) + return + } + + pending := &pendingAuth{ + ClientID: clientID, + RedirectURI: redirectURI, + ClientState: clientState, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + ProviderName: provider.providerName, + ExpiresAt: time.Now().Add(10 * time.Minute), + Scopes: scopes, + } + s.mu.Lock() + s.pending[providerState] = pending + s.mu.Unlock() + + authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + http.Redirect(w, r, authURL, http.StatusFound) +} + +// -------------------------------------------------------------------------- +// External provider callback — GET {ProviderCallbackPath} +// -------------------------------------------------------------------------- + +func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + providerState := r.URL.Query().Get("state") + + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + + s.mu.Lock() + pending, ok := s.pending[providerState] + if ok { + delete(s.pending, providerState) + } + s.mu.Unlock() + + if !ok || time.Now().After(pending.ExpiresAt) { + http.Error(w, "invalid or expired state", http.StatusBadRequest) + return + } + + provider := s.providerByName(pending.ProviderName) + if provider == nil { + http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError) + return + } + + loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + s.issueCodeAndRedirect(w, r, loginResp.Token, + pending.ClientID, pending.RedirectURI, pending.ClientState, + pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes) +} + +// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client. +func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) { + authCode, err := randomOAuthToken() + if err != nil { + http.Error(w, "server error", http.StatusInternalServerError) + return + } + + pending := &pendingAuth{ + ClientID: clientID, + RedirectURI: redirectURI, + ClientState: clientState, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + ProviderName: providerName, + SessionToken: sessionToken, + ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL), + Scopes: scopes, + } + + if s.cfg.PersistCodes && s.auth != nil { + oauthCode := &OAuthCode{ + Code: authCode, + ClientID: clientID, + RedirectURI: redirectURI, + ClientState: clientState, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + SessionToken: sessionToken, + Scopes: scopes, + ExpiresAt: pending.ExpiresAt, + } + if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil { + http.Error(w, "server error", http.StatusInternalServerError) + return + } + } else { + s.mu.Lock() + s.codes[authCode] = pending + s.mu.Unlock() + } + + redirectURL, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, "invalid redirect_uri", http.StatusInternalServerError) + return + } + qp := redirectURL.Query() + qp.Set("code", authCode) + if clientState != "" { + qp.Set("state", clientState) + } + redirectURL.RawQuery = qp.Encode() + http.Redirect(w, r, redirectURL.String(), http.StatusFound) +} + +// -------------------------------------------------------------------------- +// Token endpoint — POST /oauth/token +// -------------------------------------------------------------------------- + +func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest) + return + } + switch r.FormValue("grant_type") { + case "authorization_code": + s.handleAuthCodeGrant(w, r) + case "refresh_token": + s.handleRefreshGrant(w, r) + default: + writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest) + } +} + +func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) { + code := r.FormValue("code") + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + codeVerifier := r.FormValue("code_verifier") + + if code == "" || codeVerifier == "" { + writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest) + return + } + + var sessionToken string + var scopes []string + + if s.cfg.PersistCodes && s.auth != nil { + oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code) + if err != nil { + writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest) + return + } + if oauthCode.ClientID != clientID { + writeOAuthError(w, "invalid_client", "", http.StatusBadRequest) + return + } + if oauthCode.RedirectURI != redirectURI { + writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest) + return + } + if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) { + writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest) + return + } + sessionToken = oauthCode.SessionToken + scopes = oauthCode.Scopes + } else { + s.mu.Lock() + pending, ok := s.codes[code] + if ok { + delete(s.codes, code) + } + s.mu.Unlock() + + if !ok || time.Now().After(pending.ExpiresAt) { + writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest) + return + } + if pending.ClientID != clientID { + writeOAuthError(w, "invalid_client", "", http.StatusBadRequest) + return + } + if pending.RedirectURI != redirectURI { + writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest) + return + } + if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) { + writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest) + return + } + sessionToken = pending.SessionToken + scopes = pending.Scopes + } + + writeOAuthToken(w, sessionToken, "", scopes) +} + +func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) { + refreshToken := r.FormValue("refresh_token") + providerName := r.FormValue("provider") + if refreshToken == "" { + writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest) + return + } + + // Try external providers first, then fall back to DatabaseAuthenticator + provider := s.providerByName(providerName) + if provider != nil { + loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName) + if err != nil { + writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest) + return + } + writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil) + return + } + + if s.auth != nil { + loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken) + if err != nil { + writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest) + return + } + writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil) + return + } + + writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest) +} + +// -------------------------------------------------------------------------- +// RFC 7009 — Token revocation +// -------------------------------------------------------------------------- + +func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusOK) + return + } + token := r.FormValue("token") + if token == "" { + w.WriteHeader(http.StatusOK) + return + } + + if s.auth != nil { + s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck + } + w.WriteHeader(http.StatusOK) +} + +// -------------------------------------------------------------------------- +// RFC 7662 — Token introspection +// -------------------------------------------------------------------------- + +func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"active":false}`)) //nolint:errcheck + return + } + token := r.FormValue("token") + w.Header().Set("Content-Type", "application/json") + + if token == "" || s.auth == nil { + w.Write([]byte(`{"active":false}`)) //nolint:errcheck + return + } + + info, err := s.auth.OAuthIntrospectToken(r.Context(), token) + if err != nil { + w.Write([]byte(`{"active":false}`)) //nolint:errcheck + return + } + json.NewEncoder(w).Encode(info) //nolint:errcheck +} + +// -------------------------------------------------------------------------- +// Login form (direct identity provider mode) +// -------------------------------------------------------------------------- + +func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + errHTML := "" + if errMsg != "" { + errHTML = `

` + errMsg + `

` + } + fmt.Fprintf(w, loginFormHTML, + s.cfg.LoginTitle, + s.cfg.LoginTitle, + errHTML, + clientID, + htmlEscape(redirectURI), + htmlEscape(clientState), + htmlEscape(codeChallenge), + htmlEscape(codeChallengeMethod), + htmlEscape(scope), + ) +} + +const loginFormHTML = ` +%s + +
+

%s

%s +
+ + + + + + + + + +
` + +// -------------------------------------------------------------------------- +// Helpers +// -------------------------------------------------------------------------- + +// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled. +func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) { + s.mu.RLock() + c, ok := s.clients[clientID] + s.mu.RUnlock() + if ok { + return c, true + } + + if !s.cfg.PersistClients || s.auth == nil { + return nil, false + } + + dbClient, err := s.auth.OAuthGetClient(ctx, clientID) + if err != nil { + return nil, false + } + + c = &oauthClient{ + ClientID: dbClient.ClientID, + RedirectURIs: dbClient.RedirectURIs, + ClientName: dbClient.ClientName, + GrantTypes: dbClient.GrantTypes, + AllowedScopes: dbClient.AllowedScopes, + } + s.mu.Lock() + s.clients[clientID] = c + s.mu.Unlock() + return c, true +} + +func (s *OAuthServer) providerByName(name string) *externalProvider { + for i := range s.providers { + if s.providers[i].providerName == name { + return &s.providers[i] + } + } + // If name is empty and only one provider exists, return it + if name == "" && len(s.providers) == 1 { + return &s.providers[0] + } + return nil +} + +func validatePKCESHA256(challenge, verifier string) bool { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) == challenge +} + +func randomOAuthToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func oauthSliceContains(slice []string, s string) bool { + for _, v := range slice { + if strings.EqualFold(v, s) { + return true + } + } + return false +} + +func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) { + resp := map[string]interface{}{ + "access_token": accessToken, + "token_type": "Bearer", + "expires_in": 86400, + } + if refreshToken != "" { + resp["refresh_token"] = refreshToken + } + if len(scopes) > 0 { + resp["scope"] = strings.Join(scopes, " ") + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + json.NewEncoder(w).Encode(resp) //nolint:errcheck +} + +func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) { + resp := map[string]string{"error": errCode} + if description != "" { + resp["error_description"] = description + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(resp) //nolint:errcheck +} + +func htmlEscape(s string) string { + s = strings.ReplaceAll(s, "&", "&") + s = strings.ReplaceAll(s, `"`, """) + s = strings.ReplaceAll(s, "<", "<") + s = strings.ReplaceAll(s, ">", ">") + return s +} diff --git a/pkg/security/oauth_server_db.go b/pkg/security/oauth_server_db.go new file mode 100644 index 0000000..1206871 --- /dev/null +++ b/pkg/security/oauth_server_db.go @@ -0,0 +1,202 @@ +package security + +import ( + "context" + "encoding/json" + "fmt" + "time" +) + +// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client. +type OAuthServerClient struct { + ClientID string `json:"client_id"` + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name,omitempty"` + GrantTypes []string `json:"grant_types"` + AllowedScopes []string `json:"allowed_scopes,omitempty"` +} + +// OAuthCode is a short-lived authorization code. +type OAuthCode struct { + Code string `json:"code"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + ClientState string `json:"client_state,omitempty"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + SessionToken string `json:"session_token"` + Scopes []string `json:"scopes,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + +// OAuthTokenInfo is the RFC 7662 token introspection response. +type OAuthTokenInfo struct { + Active bool `json:"active"` + Sub string `json:"sub,omitempty"` + Username string `json:"username,omitempty"` + Email string `json:"email,omitempty"` + Roles []string `json:"roles,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` +} + +// OAuthRegisterClient persists an OAuth2 client registration. +func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) { + input, err := json.Marshal(client) + if err != nil { + return nil, fmt.Errorf("failed to marshal client: %w", err) + } + + var success bool + var errMsg *string + var data []byte + + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error, p_data::text + FROM %s($1::jsonb) + `, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + if !success { + if errMsg != nil { + return nil, fmt.Errorf("%s", *errMsg) + } + return nil, fmt.Errorf("failed to register client") + } + + var result OAuthServerClient + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse registered client: %w", err) + } + return &result, nil +} + +// OAuthGetClient retrieves a registered client by ID. +func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) { + var success bool + var errMsg *string + var data []byte + + err := a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error, p_data::text + FROM %s($1) + `, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data) + if err != nil { + return nil, fmt.Errorf("failed to get client: %w", err) + } + if !success { + if errMsg != nil { + return nil, fmt.Errorf("%s", *errMsg) + } + return nil, fmt.Errorf("client not found") + } + + var result OAuthServerClient + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse client: %w", err) + } + return &result, nil +} + +// OAuthSaveCode persists an authorization code. +func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error { + input, err := json.Marshal(code) + if err != nil { + return fmt.Errorf("failed to marshal code: %w", err) + } + + var success bool + var errMsg *string + + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error + FROM %s($1::jsonb) + `, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg) + if err != nil { + return fmt.Errorf("failed to save code: %w", err) + } + if !success { + if errMsg != nil { + return fmt.Errorf("%s", *errMsg) + } + return fmt.Errorf("failed to save code") + } + return nil +} + +// OAuthExchangeCode retrieves and deletes an authorization code (single use). +func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) { + var success bool + var errMsg *string + var data []byte + + err := a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error, p_data::text + FROM %s($1) + `, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data) + if err != nil { + return nil, fmt.Errorf("failed to exchange code: %w", err) + } + if !success { + if errMsg != nil { + return nil, fmt.Errorf("%s", *errMsg) + } + return nil, fmt.Errorf("invalid or expired code") + } + + var result OAuthCode + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse code data: %w", err) + } + result.Code = code + return &result, nil +} + +// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662). +func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) { + var success bool + var errMsg *string + var data []byte + + err := a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error, p_data::text + FROM %s($1) + `, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data) + if err != nil { + return nil, fmt.Errorf("failed to introspect token: %w", err) + } + if !success { + if errMsg != nil { + return nil, fmt.Errorf("%s", *errMsg) + } + return nil, fmt.Errorf("introspection failed") + } + + var result OAuthTokenInfo + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse token info: %w", err) + } + return &result, nil +} + +// OAuthRevokeToken revokes a token by deleting the session (RFC 7009). +func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error { + var success bool + var errMsg *string + + err := a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error + FROM %s($1) + `, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg) + if err != nil { + return fmt.Errorf("failed to revoke token: %w", err) + } + if !success { + if errMsg != nil { + return fmt.Errorf("%s", *errMsg) + } + return fmt.Errorf("failed to revoke token") + } + return nil +} diff --git a/pkg/security/sql_names.go b/pkg/security/sql_names.go index bbb594c..80265d1 100644 --- a/pkg/security/sql_names.go +++ b/pkg/security/sql_names.go @@ -54,6 +54,13 @@ type SQLNames struct { OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken" OAuthGetUser string // default: "resolvespec_oauth_getuser" + // OAuth2 server procedures (OAuthServer persistence) + OAuthRegisterClient string // default: "resolvespec_oauth_register_client" + OAuthGetClient string // default: "resolvespec_oauth_get_client" + OAuthSaveCode string // default: "resolvespec_oauth_save_code" + OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code" + OAuthIntrospect string // default: "resolvespec_oauth_introspect" + OAuthRevoke string // default: "resolvespec_oauth_revoke" } // DefaultSQLNames returns an SQLNames with all default resolvespec_* values. @@ -93,6 +100,13 @@ func DefaultSQLNames() *SQLNames { OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken", OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken", OAuthGetUser: "resolvespec_oauth_getuser", + + OAuthRegisterClient: "resolvespec_oauth_register_client", + OAuthGetClient: "resolvespec_oauth_get_client", + OAuthSaveCode: "resolvespec_oauth_save_code", + OAuthExchangeCode: "resolvespec_oauth_exchange_code", + OAuthIntrospect: "resolvespec_oauth_introspect", + OAuthRevoke: "resolvespec_oauth_revoke", } } @@ -191,6 +205,24 @@ func MergeSQLNames(base, override *SQLNames) *SQLNames { if override.OAuthGetUser != "" { merged.OAuthGetUser = override.OAuthGetUser } + if override.OAuthRegisterClient != "" { + merged.OAuthRegisterClient = override.OAuthRegisterClient + } + if override.OAuthGetClient != "" { + merged.OAuthGetClient = override.OAuthGetClient + } + if override.OAuthSaveCode != "" { + merged.OAuthSaveCode = override.OAuthSaveCode + } + if override.OAuthExchangeCode != "" { + merged.OAuthExchangeCode = override.OAuthExchangeCode + } + if override.OAuthIntrospect != "" { + merged.OAuthIntrospect = override.OAuthIntrospect + } + if override.OAuthRevoke != "" { + merged.OAuthRevoke = override.OAuthRevoke + } return &merged }