Compare commits

...

2 Commits

Author SHA1 Message Date
6502b55797 feat(security): implement OAuth2 authorization server with database support
- Add OAuthServer for handling OAuth2 flows including authorization, token exchange, and client registration.
- Introduce DatabaseAuthenticator for persisting clients and authorization codes.
- Implement SQL procedures for client registration, code saving, and token introspection.
- Support for external OAuth2 providers and PKCE (Proof Key for Code Exchange).
2026-04-07 22:56:05 +02:00
aa095d6bfd fix(tests): replace panic with log.Fatal for better error handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m52s
Build , Vet Test, and Lint / Build (push) Successful in -29m52s
Tests / Integration Tests (push) Failing after -30m46s
Tests / Unit Tests (push) Successful in -28m51s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m17s
Build , Vet Test, and Lint / Lint Code (push) Failing after -29m23s
2026-04-07 20:38:22 +02:00
11 changed files with 1941 additions and 60 deletions

View File

@@ -3,6 +3,7 @@ package providers_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager" "github.com/bitechdev/ResolveSpec/pkg/dbmanager"
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Subscribe to a channel with a handler // Subscribe to a channel with a handler
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
fmt.Printf("Received notification on %s: %s\n", channel, payload) fmt.Printf("Received notification on %s: %s\n", channel, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen: %v", err)) log.Fatalf("Failed to listen: %v", err)
} }
// Send a notification // Send a notification
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`) err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to notify: %v", err)) log.Fatalf("Failed to notify: %v", err)
} }
// Wait for notification to be processed // Wait for notification to be processed
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
// Unsubscribe from the channel // Unsubscribe from the channel
if err := listener.Unlisten("user_events"); err != nil { if err := listener.Unlisten("user_events"); err != nil {
panic(fmt.Sprintf("Failed to unlisten: %v", err)) log.Fatalf("Failed to unlisten: %v", err)
} }
} }
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Listen to multiple channels // Listen to multiple channels
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
fmt.Printf("[%s] %s\n", ch, payload) fmt.Printf("[%s] %s\n", ch, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err)) log.Fatalf("Failed to listen on %s: %v", channel, err)
} }
} }
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
provider := providers.NewPostgresProvider() provider := providers.NewPostgresProvider()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(err) log.Fatal(err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Subscribe to application events // Subscribe to application events
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// The listener automatically reconnects if the connection is lost // The listener automatically reconnects if the connection is lost

View File

@@ -142,9 +142,137 @@ e.Any("/mcp", echo.WrapHandler(h)) // Echo
--- ---
### Authentication ## OAuth2 Authentication
Add middleware before the MCP routes. The handler itself has no auth layer. `resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
It can operate as:
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
- **Both simultaneously**
### Standard endpoints served
| Path | Spec | Purpose |
|---|---|---|
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
| `POST /oauth/authorize` | — | Login form submission |
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
---
### Mode 1 — Direct login (server as identity provider)
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
db, _ := sql.Open("postgres", dsn)
auth := security.NewDatabaseAuthenticator(db)
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
BaseURL: "https://api.example.com",
BasePath: "/mcp",
})
// Enable the OAuth2 server — auth enables the login form
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
security.RegisterSecurityHooks(handler, securityList)
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
```
MCP client flow:
1. Discovers server at `/.well-known/oauth-authorization-server`
2. Registers itself at `/oauth/register`
3. Redirects user to `/oauth/authorize` → login form appears
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
5. Uses token on all MCP tool calls
---
### Mode 2 — External provider (Google, GitHub, etc.)
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
```go
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
RedirectURL: "https://api.example.com/oauth/provider/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
})
// nil = no password login; Google handles auth
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, nil)
handler.RegisterOAuth2Provider(auth, "google")
```
---
### Mode 3 — Both (login form + external providers)
```go
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
LoginTitle: "My App Login",
}, auth) // auth enables the username/password form
handler.RegisterOAuth2Provider(googleAuth, "google")
handler.RegisterOAuth2Provider(githubAuth, "github")
```
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
---
### Using `security.OAuthServer` standalone
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
```go
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
oauthSrv.RegisterExternalProvider(googleAuth, "google")
mux := http.NewServeMux()
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
mux.Handle("/mcp/", myMCPHandler)
http.ListenAndServe(":8080", mux)
```
---
### Cookie-based flow (legacy)
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
```go
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
ProviderName: "google",
LoginPath: "/auth/google/login",
CallbackPath: "/auth/google/callback",
AfterLoginRedirect: "/",
})
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
```
--- ---

View File

@@ -16,17 +16,20 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
"github.com/bitechdev/ResolveSpec/pkg/security"
) )
// Handler exposes registered database models as MCP tools and resources. // Handler exposes registered database models as MCP tools and resources.
type Handler struct { type Handler struct {
db common.Database db common.Database
registry common.ModelRegistry registry common.ModelRegistry
hooks *HookRegistry hooks *HookRegistry
mcpServer *server.MCPServer mcpServer *server.MCPServer
config Config config Config
name string name string
version string version string
oauth2Regs []oauth2Registration
oauthSrv *security.OAuthServer
} }
// NewHandler creates a Handler with the given database, model registry, and config. // NewHandler creates a Handler with the given database, model registry, and config.
@@ -197,8 +200,19 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return defaultSchema, entity return defaultSchema, entity
} }
// recoverPanic catches a panic from the current goroutine and returns it as an error.
// Usage: defer recoverPanic(&returnedErr)
func recoverPanic(err *error) {
if r := recover(); r != nil {
msg := fmt.Sprintf("%v", r)
logger.Error("[resolvemcp] panic recovered: %s", msg)
*err = fmt.Errorf("internal error: %s", msg)
}
}
// executeRead reads records from the database and returns raw data + metadata. // executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) { func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err) return nil, nil, fmt.Errorf("model not found: %w", err)
@@ -254,15 +268,6 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name)) query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
} }
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters // Filters
query = h.applyFilters(query, options.Filters) query = h.applyFilters(query, options.Filters)
@@ -304,7 +309,7 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
} }
// Count // Count — must happen before preloads are applied; Bun panics when counting with relations.
total, err := query.Count(ctx) total, err := query.Count(ctx)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err) return nil, nil, fmt.Errorf("error counting records: %w", err)
@@ -318,6 +323,15 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.Offset(*options.Offset) query = query.Offset(*options.Offset)
} }
// Preloads — applied after count to avoid Bun panic when counting with relations.
if len(options.Preload) > 0 {
var preloadErr error
query, preloadErr = h.applyPreloads(model, query, options.Preload)
if preloadErr != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
}
}
// BeforeRead hook // BeforeRead hook
hookCtx.Query = query hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil { if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
@@ -378,7 +392,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
// executeCreate inserts one or more records. // executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) { func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -462,7 +477,8 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
} }
// executeUpdate updates a record by ID. // executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) { func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -572,7 +588,8 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
} }
// executeDelete deletes a record by ID. // executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) { func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
if id == "" { if id == "" {
return nil, fmt.Errorf("delete requires an ID") return nil, fmt.Errorf("delete requires an ID")
} }

264
pkg/resolvemcp/oauth2.go Normal file
View File

@@ -0,0 +1,264 @@
package resolvemcp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// --------------------------------------------------------------------------
// OAuth2 registration on the Handler
// --------------------------------------------------------------------------
// oauth2Registration stores a configured auth provider and its route config.
type oauth2Registration struct {
auth *security.DatabaseAuthenticator
cfg OAuth2RouteConfig
}
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
// Call this once per provider before serving requests.
//
// Example:
//
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
// ProviderName: "google",
// LoginPath: "/auth/google/login",
// CallbackPath: "/auth/google/callback",
// AfterLoginRedirect: "/",
// })
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
}
// HTTPHandler returns a single http.Handler that serves:
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
// - The MCP SSE transport wrapped with required authentication middleware
//
// Example:
//
// auth := security.NewGoogleAuthenticator(...)
// handler.RegisterOAuth2(auth, cfg)
// handler.EnableOAuthServer(resolvemcp.OAuthServerConfig{Issuer: "https://api.example.com"})
// security.RegisterSecurityHooks(handler, securityList)
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
mux := http.NewServeMux()
if h.oauthSrv != nil {
h.mountOAuthServerRoutes(mux)
}
h.mountOAuth2Routes(mux)
mcpHandler := h.AuthedSSEServer(securityList)
basePath := h.config.BasePath
if basePath == "" {
basePath = "/mcp"
}
mux.Handle(basePath+"/sse", mcpHandler)
mux.Handle(basePath+"/message", mcpHandler)
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
return mux
}
// StreamableHTTPMux returns a single http.Handler that serves:
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
// - The MCP streamable HTTP transport wrapped with required authentication middleware
//
// Example:
//
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
mux := http.NewServeMux()
if h.oauthSrv != nil {
h.mountOAuthServerRoutes(mux)
}
h.mountOAuth2Routes(mux)
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
basePath := h.config.BasePath
if basePath == "" {
basePath = "/mcp"
}
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
mux.Handle(basePath, mcpHandler)
return mux
}
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
for _, reg := range h.oauth2Regs {
var cookieOpts []security.SessionCookieOptions
if reg.cfg.CookieOptions != nil {
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
}
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
}
}
// --------------------------------------------------------------------------
// Auth-wrapped transports
// --------------------------------------------------------------------------
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
// The middleware reads the session cookie / Authorization header and populates the user
// context into the request context, making it available to BeforeHandle security hooks.
// Unauthenticated requests receive 401 before reaching any MCP tool.
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
return security.NewAuthMiddleware(securityList)(h.SSEServer())
}
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
// Unauthenticated requests continue as guest rather than returning 401.
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
// to allow mixed public/private access.
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
}
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
}
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
}
// --------------------------------------------------------------------------
// OAuth2 route config and standalone handlers
// --------------------------------------------------------------------------
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
type OAuth2RouteConfig struct {
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
// (e.g. "google", "github", "microsoft").
ProviderName string
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
// (e.g. "/auth/google/login").
LoginPath string
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
CallbackPath string
// AfterLoginRedirect is the URL to redirect the browser to after a successful
// login. When empty the LoginResponse JSON is written directly to the response.
AfterLoginRedirect string
// CookieOptions customises the session cookie written on successful login.
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
CookieOptions *security.SessionCookieOptions
}
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
// the OAuth2 provider's authorization URL.
//
// Register it on any router:
//
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := auth.OAuth2GenerateState()
if err != nil {
http.Error(w, "failed to generate state", http.StatusInternalServerError)
return
}
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
}
}
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
// callback: exchanges the authorization code for a session token, writes the session
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
//
// Register it on any router:
//
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code parameter", http.StatusBadRequest)
return
}
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
security.SetSessionCookie(w, loginResp, cookieOpts...)
if afterLoginRedirect != "" {
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
}
}
// --------------------------------------------------------------------------
// Gorilla Mux convenience helpers
// --------------------------------------------------------------------------
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
//
// Example:
//
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
// ProviderName: "google", LoginPath: "/auth/google/login",
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
// })
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
var cookieOpts []security.SessionCookieOptions
if cfg.CookieOptions != nil {
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
}
muxRouter.Handle(cfg.LoginPath,
OAuth2LoginHandler(auth, cfg.ProviderName),
).Methods(http.MethodGet)
muxRouter.Handle(cfg.CallbackPath,
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
).Methods(http.MethodGet)
}
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
// with required authentication middleware applied.
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
basePath := handler.config.BasePath
h := handler.AuthedSSEServer(securityList)
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
// Gorilla Mux router with required authentication middleware applied.
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
basePath := handler.config.BasePath
h := handler.AuthedStreamableHTTPServer(securityList)
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}

View File

@@ -0,0 +1,51 @@
package resolvemcp
import (
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
//
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
// only external providers registered via RegisterOAuth2Provider.
//
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
// endpoints required by MCP clients alongside the MCP transport:
//
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
// POST /oauth/register RFC 7591 — dynamic client registration
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
// POST /oauth/authorize Login form submission (password flow)
// POST /oauth/token Bearer token exchange + refresh
// GET /oauth/provider/callback External provider redirect target
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
h.oauthSrv = security.NewOAuthServer(cfg, auth)
// Wire any external providers already registered via RegisterOAuth2
for _, reg := range h.oauth2Regs {
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
}
}
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
// EnableOAuthServer must be called before this. The auth must have been configured with
// WithOAuth2(providerName, ...) for the given provider name.
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
if h.oauthSrv != nil {
h.oauthSrv.RegisterExternalProvider(auth, providerName)
}
}
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
oauthHandler := h.oauthSrv.HTTPHandler()
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
mux.Handle("/.well-known/", oauthHandler)
mux.Handle("/oauth/", oauthHandler)
if h.oauthSrv != nil {
// Also mount the external provider callback path if it differs from /oauth/
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
}
}

View File

@@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
-**Testable** - Easy to mock and test -**Testable** - Easy to mock and test
-**Extensible** - Implement custom providers for your needs -**Extensible** - Implement custom providers for your needs
-**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability -**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
-**OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
## Stored Procedure Architecture ## Stored Procedure Architecture
@@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator | | `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider | | `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider | | `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
See `database_schema.sql` for complete stored procedure definitions and examples. See `database_schema.sql` for complete stored procedure definitions and examples.
@@ -897,6 +904,155 @@ securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
``` ```
## OAuth2 Authorization Server
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
### Endpoints
| Method | Path | RFC |
|--------|------|-----|
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
### Config
```go
cfg := security.OAuthServerConfig{
Issuer: "https://example.com", // Required — token issuer URL
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
LoginTitle: "My App Login", // HTML login page title
PersistClients: true, // Store clients in DB (multi-instance safe)
PersistCodes: true, // Store codes in DB (multi-instance safe)
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
AccessTokenTTL: time.Hour,
AuthCodeTTL: 5 * time.Minute,
}
```
| Field | Default | Notes |
|-------|---------|-------|
| `Issuer` | — | Required |
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
| `LoginTitle` | `"Login"` | |
| `PersistClients` | `false` | Set `true` for multi-instance |
| `PersistCodes` | `false` | Set `true` for multi-instance |
| `DefaultScopes` | `nil` | |
| `AccessTokenTTL` | `1h` | |
| `AuthCodeTTL` | `5m` | |
### Operating Modes
**Mode 1 — Direct login (username/password form)**
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
```go
auth := security.NewDatabaseAuthenticator(db)
srv := security.NewOAuthServer(cfg, auth)
```
**Mode 2 — External provider federation**
Pass `nil` as auth and register external providers. The authorize page shows a provider selection UI.
```go
srv := security.NewOAuthServer(cfg, nil)
srv.RegisterExternalProvider(googleAuth, "google")
srv.RegisterExternalProvider(githubAuth, "github")
```
**Mode 3 — Both**
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
```go
srv := security.NewOAuthServer(cfg, auth)
srv.RegisterExternalProvider(googleAuth, "google")
```
### Standalone Usage
```go
mux := http.NewServeMux()
mux.Handle("/.well-known/", srv.HTTPHandler())
mux.Handle("/oauth/", srv.HTTPHandler())
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
http.ListenAndServe(":8080", mux)
```
### DB Persistence
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
#### New DB Types
```go
type OAuthServerClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
ClientState string `json:"client_state,omitempty"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
SessionToken string `json:"session_token"`
Scopes []string `json:"scopes,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
}
type OAuthTokenInfo struct {
Active bool `json:"active"`
Sub string `json:"sub,omitempty"`
Username string `json:"username,omitempty"`
Email string `json:"email,omitempty"`
Roles []string `json:"roles,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
}
```
#### DatabaseAuthenticator OAuth Methods
```go
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
auth.OAuthGetClient(ctx, clientID) // retrieve client
auth.OAuthSaveCode(ctx, code) // persist authorization code
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
```
#### SQLNames Fields
```go
type SQLNames struct {
// ... existing fields ...
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
OAuthGetClient string // default: "resolvespec_oauth_get_client"
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
OAuthRevoke string // default: "resolvespec_oauth_revoke"
}
```
The main changes: The main changes:
1. Security package no longer knows about specific spec types 1. Security package no longer knows about specific spec types
2. Each spec registers its own security hooks 2. Each spec registers its own security hooks

View File

@@ -1397,3 +1397,173 @@ $$ LANGUAGE plpgsql;
-- Get credentials by username -- Get credentials by username
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin'); -- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
-- ============================================
-- OAuth2 Server Tables (OAuthServer persistence)
-- ============================================
-- oauth_clients: persistent RFC 7591 registered clients
CREATE TABLE IF NOT EXISTS oauth_clients (
id SERIAL PRIMARY KEY,
client_id VARCHAR(255) NOT NULL UNIQUE,
redirect_uris TEXT[] NOT NULL,
client_name VARCHAR(255),
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
CREATE TABLE IF NOT EXISTS oauth_codes (
id SERIAL PRIMARY KEY,
code VARCHAR(255) NOT NULL UNIQUE,
client_id VARCHAR(255) NOT NULL REFERENCES oauth_clients(client_id) ON DELETE CASCADE,
redirect_uri TEXT NOT NULL,
client_state TEXT,
code_challenge VARCHAR(255) NOT NULL,
code_challenge_method VARCHAR(10) DEFAULT 'S256',
session_token TEXT NOT NULL,
scopes TEXT[],
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
-- ============================================
-- OAuth2 Server Stored Procedures
-- ============================================
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_client_id text;
v_row jsonb;
BEGIN
v_client_id := p_data->>'client_id';
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
VALUES (
v_client_id,
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
p_data->>'client_name',
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
)
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
RETURN QUERY SELECT true, null::text, v_row;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
SELECT to_jsonb(oauth_clients.*)
INTO v_row
FROM oauth_clients
WHERE client_id = p_client_id AND is_active = true;
IF v_row IS NULL THEN
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
RETURNS TABLE(p_success bool, p_error text)
LANGUAGE plpgsql AS $$
BEGIN
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, scopes, expires_at)
VALUES (
p_data->>'code',
p_data->>'client_id',
p_data->>'redirect_uri',
p_data->>'client_state',
p_data->>'code_challenge',
COALESCE(p_data->>'code_challenge_method', 'S256'),
p_data->>'session_token',
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
(p_data->>'expires_at')::timestamp
);
RETURN QUERY SELECT true, null::text;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
DELETE FROM oauth_codes
WHERE code = p_code AND expires_at > now()
RETURNING jsonb_build_object(
'client_id', client_id,
'redirect_uri', redirect_uri,
'client_state', client_state,
'code_challenge', code_challenge,
'code_challenge_method', code_challenge_method,
'session_token', session_token,
'scopes', to_jsonb(scopes)
) INTO v_row;
IF v_row IS NULL THEN
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
SELECT jsonb_build_object(
'active', true,
'sub', u.id::text,
'username', u.username,
'email', u.email,
'user_level', u.user_level,
'roles', to_jsonb(string_to_array(COALESCE(u.roles, ''), ',')),
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
)
INTO v_row
FROM user_sessions s
JOIN users u ON u.id = s.user_id
WHERE s.session_token = p_token
AND s.expires_at > now()
AND u.is_active = true;
IF v_row IS NULL THEN
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
RETURNS TABLE(p_success bool, p_error text)
LANGUAGE plpgsql AS $$
BEGIN
DELETE FROM user_sessions WHERE session_token = p_token;
RETURN QUERY SELECT true, null::text;
END;
$$;

View File

@@ -0,0 +1,859 @@
package security
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
type OAuthServerConfig struct {
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
Issuer string
// ProviderCallbackPath is the path on this server that external OAuth2 providers
// redirect back to. Defaults to "/oauth/provider/callback".
ProviderCallbackPath string
// LoginTitle is shown on the built-in login form when the server acts as its own
// identity provider. Defaults to "MCP Login".
LoginTitle string
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
// Clients registered during a session survive server restarts.
PersistClients bool
// PersistCodes stores authorization codes in the database.
// Useful for multi-instance deployments. Defaults to in-memory.
PersistCodes bool
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
DefaultScopes []string
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
AccessTokenTTL time.Duration
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
AuthCodeTTL time.Duration
}
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
type oauthClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
// pendingAuth tracks an in-progress authorization code exchange.
type pendingAuth struct {
ClientID string
RedirectURI string
ClientState string
CodeChallenge string
CodeChallengeMethod string
ProviderName string // empty = password login
ExpiresAt time.Time
SessionToken string // set after authentication completes
Scopes []string // requested scopes
}
// externalProvider pairs a DatabaseAuthenticator with its provider name.
type externalProvider struct {
auth *DatabaseAuthenticator
providerName string
}
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
//
// It can act as both:
// - A direct identity provider using DatabaseAuthenticator username/password login
// - A federation layer that delegates authentication to external OAuth2 providers
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
//
// The server exposes these RFC-compliant endpoints:
//
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
// POST /oauth/register RFC 7591 — dynamic client registration
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
// POST /oauth/authorize Direct login form submission
// POST /oauth/token Token exchange and refresh
// POST /oauth/revoke RFC 7009 — token revocation
// POST /oauth/introspect RFC 7662 — token introspection
// GET {ProviderCallbackPath} Internal — external provider callback
type OAuthServer struct {
cfg OAuthServerConfig
auth *DatabaseAuthenticator // nil = only external providers
providers []externalProvider
mu sync.RWMutex
clients map[string]*oauthClient
pending map[string]*pendingAuth // provider_state → pending (external flow)
codes map[string]*pendingAuth // auth_code → pending (post-auth)
}
// NewOAuthServer creates a new MCP OAuth2 authorization server.
//
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
// acts as its own identity provider). Pass nil to use only external providers.
// External providers are added separately via RegisterExternalProvider.
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
if cfg.ProviderCallbackPath == "" {
cfg.ProviderCallbackPath = "/oauth/provider/callback"
}
if cfg.LoginTitle == "" {
cfg.LoginTitle = "Sign in"
}
if len(cfg.DefaultScopes) == 0 {
cfg.DefaultScopes = []string{"openid", "profile", "email"}
}
if cfg.AccessTokenTTL == 0 {
cfg.AccessTokenTTL = 24 * time.Hour
}
if cfg.AuthCodeTTL == 0 {
cfg.AuthCodeTTL = 2 * time.Minute
}
s := &OAuthServer{
cfg: cfg,
auth: auth,
clients: make(map[string]*oauthClient),
pending: make(map[string]*pendingAuth),
codes: make(map[string]*pendingAuth),
}
go s.cleanupExpired()
return s
}
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
// configured with WithOAuth2(providerName, ...) before calling this.
// Multiple providers can be registered; the first is used as the default.
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
}
// ProviderCallbackPath returns the configured path for external provider callbacks.
func (s *OAuthServer) ProviderCallbackPath() string {
return s.cfg.ProviderCallbackPath
}
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
// Mount it at the root of your HTTP server alongside the MCP transport.
//
// mux := http.NewServeMux()
// mux.Handle("/", oauthServer.HTTPHandler())
// mux.Handle("/mcp/", mcpTransport)
func (s *OAuthServer) HTTPHandler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
mux.HandleFunc("/oauth/register", s.registerHandler)
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
mux.HandleFunc("/oauth/token", s.tokenHandler)
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
return mux
}
// cleanupExpired removes stale pending auths and codes every 5 minutes.
func (s *OAuthServer) cleanupExpired() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
s.mu.Lock()
for k, p := range s.pending {
if now.After(p.ExpiresAt) {
delete(s.pending, k)
}
}
for k, p := range s.codes {
if now.After(p.ExpiresAt) {
delete(s.codes, k)
}
}
s.mu.Unlock()
}
}
// --------------------------------------------------------------------------
// RFC 8414 — Server metadata
// --------------------------------------------------------------------------
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
issuer := s.cfg.Issuer
meta := map[string]interface{}{
"issuer": issuer,
"authorization_endpoint": issuer + "/oauth/authorize",
"token_endpoint": issuer + "/oauth/token",
"registration_endpoint": issuer + "/oauth/register",
"revocation_endpoint": issuer + "/oauth/revoke",
"introspection_endpoint": issuer + "/oauth/introspect",
"scopes_supported": s.cfg.DefaultScopes,
"response_types_supported": []string{"code"},
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"code_challenge_methods_supported": []string{"S256"},
"token_endpoint_auth_methods_supported": []string{"none"},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(meta) //nolint:errcheck
}
// --------------------------------------------------------------------------
// RFC 7591 — Dynamic client registration
// --------------------------------------------------------------------------
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
return
}
if len(req.RedirectURIs) == 0 {
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
return
}
grantTypes := req.GrantTypes
if len(grantTypes) == 0 {
grantTypes = []string{"authorization_code"}
}
allowedScopes := req.AllowedScopes
if len(allowedScopes) == 0 {
allowedScopes = s.cfg.DefaultScopes
}
clientID, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
client := &oauthClient{
ClientID: clientID,
RedirectURIs: req.RedirectURIs,
ClientName: req.ClientName,
GrantTypes: grantTypes,
AllowedScopes: allowedScopes,
}
if s.cfg.PersistClients && s.auth != nil {
dbClient := &OAuthServerClient{
ClientID: client.ClientID,
RedirectURIs: client.RedirectURIs,
ClientName: client.ClientName,
GrantTypes: client.GrantTypes,
AllowedScopes: client.AllowedScopes,
}
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
}
s.mu.Lock()
s.clients[clientID] = client
s.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(client) //nolint:errcheck
}
// --------------------------------------------------------------------------
// Authorization endpoint — GET + POST /oauth/authorize
// --------------------------------------------------------------------------
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
s.authorizeGet(w, r)
case http.MethodPost:
s.authorizePost(w, r)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// authorizeGet validates the request and either:
// - Redirects to an external provider (if providers are registered)
// - Renders a login form (if the server is its own identity provider)
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
clientID := q.Get("client_id")
redirectURI := q.Get("redirect_uri")
clientState := q.Get("state")
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
providerName := q.Get("provider")
scopeStr := q.Get("scope")
var scopes []string
if scopeStr != "" {
scopes = strings.Fields(scopeStr)
}
if q.Get("response_type") != "code" {
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
return
}
if codeChallenge == "" {
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
return
}
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
return
}
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
if !ok {
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
return
}
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
return
}
// External provider path
if len(s.providers) > 0 {
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
return
}
// Direct login form path (server is its own identity provider)
if s.auth == nil {
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
return
}
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
}
// authorizePost handles login form submission for the direct login flow.
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
clientID := r.FormValue("client_id")
redirectURI := r.FormValue("redirect_uri")
clientState := r.FormValue("client_state")
codeChallenge := r.FormValue("code_challenge")
codeChallengeMethod := r.FormValue("code_challenge_method")
username := r.FormValue("username")
password := r.FormValue("password")
scopeStr := r.FormValue("scope")
var scopes []string
if scopeStr != "" {
scopes = strings.Fields(scopeStr)
}
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
return
}
if s.auth == nil {
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
return
}
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
Username: username,
Password: password,
})
if err != nil {
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
return
}
s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
}
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
var provider *externalProvider
if providerName != "" {
for i := range s.providers {
if s.providers[i].providerName == providerName {
provider = &s.providers[i]
break
}
}
if provider == nil {
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
return
}
} else {
provider = &s.providers[0]
}
providerState, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
pending := &pendingAuth{
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ProviderName: provider.providerName,
ExpiresAt: time.Now().Add(10 * time.Minute),
Scopes: scopes,
}
s.mu.Lock()
s.pending[providerState] = pending
s.mu.Unlock()
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusFound)
}
// --------------------------------------------------------------------------
// External provider callback — GET {ProviderCallbackPath}
// --------------------------------------------------------------------------
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
providerState := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code", http.StatusBadRequest)
return
}
s.mu.Lock()
pending, ok := s.pending[providerState]
if ok {
delete(s.pending, providerState)
}
s.mu.Unlock()
if !ok || time.Now().After(pending.ExpiresAt) {
http.Error(w, "invalid or expired state", http.StatusBadRequest)
return
}
provider := s.providerByName(pending.ProviderName)
if provider == nil {
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
return
}
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
s.issueCodeAndRedirect(w, r, loginResp.Token,
pending.ClientID, pending.RedirectURI, pending.ClientState,
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
}
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
authCode, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
pending := &pendingAuth{
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ProviderName: providerName,
SessionToken: sessionToken,
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
Scopes: scopes,
}
if s.cfg.PersistCodes && s.auth != nil {
oauthCode := &OAuthCode{
Code: authCode,
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
SessionToken: sessionToken,
Scopes: scopes,
ExpiresAt: pending.ExpiresAt,
}
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
} else {
s.mu.Lock()
s.codes[authCode] = pending
s.mu.Unlock()
}
redirectURL, err := url.Parse(redirectURI)
if err != nil {
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
return
}
qp := redirectURL.Query()
qp.Set("code", authCode)
if clientState != "" {
qp.Set("state", clientState)
}
redirectURL.RawQuery = qp.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// --------------------------------------------------------------------------
// Token endpoint — POST /oauth/token
// --------------------------------------------------------------------------
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
return
}
switch r.FormValue("grant_type") {
case "authorization_code":
s.handleAuthCodeGrant(w, r)
case "refresh_token":
s.handleRefreshGrant(w, r)
default:
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
}
}
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
code := r.FormValue("code")
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
codeVerifier := r.FormValue("code_verifier")
if code == "" || codeVerifier == "" {
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
return
}
var sessionToken string
var scopes []string
if s.cfg.PersistCodes && s.auth != nil {
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
if err != nil {
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
return
}
if oauthCode.ClientID != clientID {
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
return
}
if oauthCode.RedirectURI != redirectURI {
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
return
}
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
return
}
sessionToken = oauthCode.SessionToken
scopes = oauthCode.Scopes
} else {
s.mu.Lock()
pending, ok := s.codes[code]
if ok {
delete(s.codes, code)
}
s.mu.Unlock()
if !ok || time.Now().After(pending.ExpiresAt) {
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
return
}
if pending.ClientID != clientID {
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
return
}
if pending.RedirectURI != redirectURI {
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
return
}
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
return
}
sessionToken = pending.SessionToken
scopes = pending.Scopes
}
writeOAuthToken(w, sessionToken, "", scopes)
}
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
refreshToken := r.FormValue("refresh_token")
providerName := r.FormValue("provider")
if refreshToken == "" {
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
return
}
// Try external providers first, then fall back to DatabaseAuthenticator
provider := s.providerByName(providerName)
if provider != nil {
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
if err != nil {
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
return
}
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
return
}
if s.auth != nil {
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
if err != nil {
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
return
}
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
return
}
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
}
// --------------------------------------------------------------------------
// RFC 7009 — Token revocation
// --------------------------------------------------------------------------
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusOK)
return
}
token := r.FormValue("token")
if token == "" {
w.WriteHeader(http.StatusOK)
return
}
if s.auth != nil {
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
}
w.WriteHeader(http.StatusOK)
}
// --------------------------------------------------------------------------
// RFC 7662 — Token introspection
// --------------------------------------------------------------------------
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
token := r.FormValue("token")
w.Header().Set("Content-Type", "application/json")
if token == "" || s.auth == nil {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
info, err := s.auth.OAuthIntrospectToken(r.Context(), token)
if err != nil {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
json.NewEncoder(w).Encode(info) //nolint:errcheck
}
// --------------------------------------------------------------------------
// Login form (direct identity provider mode)
// --------------------------------------------------------------------------
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
errHTML := ""
if errMsg != "" {
errHTML = `<p style="color:red">` + errMsg + `</p>`
}
fmt.Fprintf(w, loginFormHTML,
s.cfg.LoginTitle,
s.cfg.LoginTitle,
errHTML,
clientID,
htmlEscape(redirectURI),
htmlEscape(clientState),
htmlEscape(codeChallenge),
htmlEscape(codeChallengeMethod),
htmlEscape(scope),
)
}
const loginFormHTML = `<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>%s</title>
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
h2{margin:0 0 1.5rem;font-size:1.25rem}
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
</head><body><div class="card">
<h2>%s</h2>%s
<form method="POST" action="/oauth/authorize">
<input type="hidden" name="client_id" value="%s">
<input type="hidden" name="redirect_uri" value="%s">
<input type="hidden" name="client_state" value="%s">
<input type="hidden" name="code_challenge" value="%s">
<input type="hidden" name="code_challenge_method" value="%s">
<input type="hidden" name="scope" value="%s">
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
<label>Password</label><input type="password" name="password" autocomplete="current-password">
<button type="submit">Sign in</button>
</form></div></body></html>`
// --------------------------------------------------------------------------
// Helpers
// --------------------------------------------------------------------------
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
s.mu.RLock()
c, ok := s.clients[clientID]
s.mu.RUnlock()
if ok {
return c, true
}
if !s.cfg.PersistClients || s.auth == nil {
return nil, false
}
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
if err != nil {
return nil, false
}
c = &oauthClient{
ClientID: dbClient.ClientID,
RedirectURIs: dbClient.RedirectURIs,
ClientName: dbClient.ClientName,
GrantTypes: dbClient.GrantTypes,
AllowedScopes: dbClient.AllowedScopes,
}
s.mu.Lock()
s.clients[clientID] = c
s.mu.Unlock()
return c, true
}
func (s *OAuthServer) providerByName(name string) *externalProvider {
for i := range s.providers {
if s.providers[i].providerName == name {
return &s.providers[i]
}
}
// If name is empty and only one provider exists, return it
if name == "" && len(s.providers) == 1 {
return &s.providers[0]
}
return nil
}
func validatePKCESHA256(challenge, verifier string) bool {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
}
func randomOAuthToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func oauthSliceContains(slice []string, s string) bool {
for _, v := range slice {
if strings.EqualFold(v, s) {
return true
}
}
return false
}
func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
resp := map[string]interface{}{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": 86400,
}
if refreshToken != "" {
resp["refresh_token"] = refreshToken
}
if len(scopes) > 0 {
resp["scope"] = strings.Join(scopes, " ")
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
json.NewEncoder(w).Encode(resp) //nolint:errcheck
}
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
resp := map[string]string{"error": errCode}
if description != "" {
resp["error_description"] = description
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(resp) //nolint:errcheck
}
func htmlEscape(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, `"`, "&#34;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
return s
}

View File

@@ -0,0 +1,202 @@
package security
import (
"context"
"encoding/json"
"fmt"
"time"
)
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
type OAuthServerClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
// OAuthCode is a short-lived authorization code.
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
ClientState string `json:"client_state,omitempty"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
SessionToken string `json:"session_token"`
Scopes []string `json:"scopes,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
}
// OAuthTokenInfo is the RFC 7662 token introspection response.
type OAuthTokenInfo struct {
Active bool `json:"active"`
Sub string `json:"sub,omitempty"`
Username string `json:"username,omitempty"`
Email string `json:"email,omitempty"`
Roles []string `json:"roles,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
}
// OAuthRegisterClient persists an OAuth2 client registration.
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
input, err := json.Marshal(client)
if err != nil {
return nil, fmt.Errorf("failed to marshal client: %w", err)
}
var success bool
var errMsg *string
var data []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1::jsonb)
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("failed to register client")
}
var result OAuthServerClient
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse registered client: %w", err)
}
return &result, nil
}
// OAuthGetClient retrieves a registered client by ID.
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("client not found")
}
var result OAuthServerClient
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse client: %w", err)
}
return &result, nil
}
// OAuthSaveCode persists an authorization code.
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
input, err := json.Marshal(code)
if err != nil {
return fmt.Errorf("failed to marshal code: %w", err)
}
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to save code: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to save code")
}
return nil
}
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("invalid or expired code")
}
var result OAuthCode
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse code data: %w", err)
}
result.Code = code
return &result, nil
}
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to introspect token: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("introspection failed")
}
var result OAuthTokenInfo
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse token info: %w", err)
}
return &result, nil
}
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
var success bool
var errMsg *string
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1)
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to revoke token")
}
return nil
}

View File

@@ -54,6 +54,13 @@ type SQLNames struct {
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken" OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser" OAuthGetUser string // default: "resolvespec_oauth_getuser"
// OAuth2 server procedures (OAuthServer persistence)
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
OAuthGetClient string // default: "resolvespec_oauth_get_client"
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
OAuthRevoke string // default: "resolvespec_oauth_revoke"
} }
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values. // DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
@@ -93,6 +100,13 @@ func DefaultSQLNames() *SQLNames {
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken", OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken", OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser", OAuthGetUser: "resolvespec_oauth_getuser",
OAuthRegisterClient: "resolvespec_oauth_register_client",
OAuthGetClient: "resolvespec_oauth_get_client",
OAuthSaveCode: "resolvespec_oauth_save_code",
OAuthExchangeCode: "resolvespec_oauth_exchange_code",
OAuthIntrospect: "resolvespec_oauth_introspect",
OAuthRevoke: "resolvespec_oauth_revoke",
} }
} }
@@ -191,6 +205,24 @@ func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override.OAuthGetUser != "" { if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser merged.OAuthGetUser = override.OAuthGetUser
} }
if override.OAuthRegisterClient != "" {
merged.OAuthRegisterClient = override.OAuthRegisterClient
}
if override.OAuthGetClient != "" {
merged.OAuthGetClient = override.OAuthGetClient
}
if override.OAuthSaveCode != "" {
merged.OAuthSaveCode = override.OAuthSaveCode
}
if override.OAuthExchangeCode != "" {
merged.OAuthExchangeCode = override.OAuthExchangeCode
}
if override.OAuthIntrospect != "" {
merged.OAuthIntrospect = override.OAuthIntrospect
}
if override.OAuthRevoke != "" {
merged.OAuthRevoke = override.OAuthRevoke
}
return &merged return &merged
} }

View File

@@ -3,6 +3,7 @@ package server_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"time" "time"
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
GZIP: true, // Enable GZIP compression GZIP: true, // Enable GZIP compression
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Server is now running... // Server is now running...
// When done, stop gracefully // When done, stop gracefully
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -61,7 +62,7 @@ func ExampleManager_https() {
SSLKey: "/path/to/key.pem", SSLKey: "/path/to/key.pem",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 2: Self-signed certificate (for development) // Option 2: Self-signed certificate (for development)
@@ -73,27 +74,27 @@ func ExampleManager_https() {
SelfSignedSSL: true, SelfSignedSSL: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 3: Let's Encrypt / AutoTLS (for production) // Option 3: Let's Encrypt / AutoTLS (for production)
_, err = mgr.Add(server.Config{ _, err = mgr.Add(server.Config{
Name: "https-server-letsencrypt", Name: "https-server-letsencrypt",
Host: "0.0.0.0", Host: "0.0.0.0",
Port: 443, Port: 443,
Handler: handler, Handler: handler,
AutoTLS: true, AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"}, AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com", AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache", AutoTLSCacheDir: "./certs-cache",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Cleanup // Cleanup
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start servers and block until shutdown signal (SIGINT/SIGTERM) // Start servers and block until shutdown signal (SIGINT/SIGTERM)
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
Handler: mux, Handler: mux,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Add health and readiness endpoints // Add health and readiness endpoints
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
// Start the server // Start the server
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Health check returns: // Health check returns:
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
GZIP: true, GZIP: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Admin API server (different port) // Admin API server (different port)
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
Handler: adminHandler, Handler: adminHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Metrics server (internal only) // Metrics server (internal only)
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
Handler: metricsHandler, Handler: metricsHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers at once // Start all servers at once
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Get specific server instance // Get specific server instance
publicInstance, err := mgr.Get("public-api") publicInstance, err := mgr.Get("public-api")
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
fmt.Printf("Public API running on: %s\n", publicInstance.Addr()) fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
// Stop all servers gracefully (in parallel) // Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
Handler: handler, Handler: handler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Check server status // Check server status