mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
Compare commits
9 Commits
v1.0.70
...
feature-au
| Author | SHA1 | Date | |
|---|---|---|---|
| 6502b55797 | |||
| aa095d6bfd | |||
| ea5bb38ee4 | |||
| c2e2c9b873 | |||
| 4adf94fe37 | |||
|
|
405a04a192 | ||
|
|
c1b16d363a | ||
|
|
568df8c6d6 | ||
|
|
aa362c77da |
2
go.mod
2
go.mod
@@ -15,6 +15,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/mattn/go-sqlite3 v1.14.33
|
||||
github.com/microsoft/go-mssqldb v1.9.5
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||
@@ -88,7 +89,6 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mark3labs/mcp-go v0.46.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
|
||||
@@ -26,6 +26,7 @@ type Connection interface {
|
||||
Bun() (*bun.DB, error)
|
||||
GORM() (*gorm.DB, error)
|
||||
Native() (*sql.DB, error)
|
||||
DB() (*sql.DB, error)
|
||||
|
||||
// Common Database interface (for SQL databases)
|
||||
Database() (common.Database, error)
|
||||
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
|
||||
// DB returns the underlying *sql.DB connection
|
||||
func (c *sqlConnection) DB() (*sql.DB, error) {
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
// Bun returns a Bun ORM instance wrapping the native connection
|
||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||
if c == nil {
|
||||
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// DB returns an error for MongoDB connections
|
||||
func (c *mongoConnection) DB() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Database returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Database() (common.Database, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
|
||||
@@ -3,6 +3,7 @@ package providers_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to a channel with a handler
|
||||
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
|
||||
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen: %v", err))
|
||||
log.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to notify: %v", err))
|
||||
log.Fatalf("Failed to notify: %v", err)
|
||||
}
|
||||
|
||||
// Wait for notification to be processed
|
||||
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
|
||||
|
||||
// Unsubscribe from the channel
|
||||
if err := listener.Unlisten("user_events"); err != nil {
|
||||
panic(fmt.Sprintf("Failed to unlisten: %v", err))
|
||||
log.Fatalf("Failed to unlisten: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// Listen to multiple channels
|
||||
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
|
||||
fmt.Printf("[%s] %s\n", ch, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
|
||||
log.Fatalf("Failed to listen on %s: %v", channel, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Subscribe to application events
|
||||
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// The listener automatically reconnects if the connection is lost
|
||||
|
||||
@@ -67,55 +67,289 @@ Each call immediately creates four MCP **tools** and one MCP **resource** for th
|
||||
|
||||
---
|
||||
|
||||
## HTTP / SSE Transport
|
||||
|
||||
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
|
||||
## HTTP Transports
|
||||
|
||||
`Config.BasePath` is required and used for all route registration.
|
||||
`Config.BaseURL` is optional — when empty it is detected from each request.
|
||||
|
||||
### Gorilla Mux
|
||||
Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
|
||||
|
||||
---
|
||||
|
||||
### SSE Transport
|
||||
|
||||
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
|
||||
|
||||
#### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
Registers:
|
||||
|
||||
| Route | Method | Description |
|
||||
|---|---|---|
|
||||
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
|
||||
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
|
||||
| `{BasePath}/*` | any | Full SSE server (convenience prefix) |
|
||||
|
||||
### bunrouter
|
||||
#### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterRoutes(router, handler)
|
||||
```
|
||||
|
||||
Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`.
|
||||
|
||||
### Gin (or any `http.Handler`-compatible framework)
|
||||
|
||||
Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
|
||||
#### Gin / net/http / Echo
|
||||
|
||||
```go
|
||||
sse := handler.SSEServer()
|
||||
|
||||
// Gin
|
||||
engine.Any("/mcp/*path", gin.WrapH(sse))
|
||||
|
||||
// net/http
|
||||
http.Handle("/mcp/", sse)
|
||||
|
||||
// Echo
|
||||
e.Any("/mcp/*", echo.WrapHandler(sse))
|
||||
engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
|
||||
http.Handle("/mcp/", sse) // net/http
|
||||
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
|
||||
```
|
||||
|
||||
### Authentication
|
||||
---
|
||||
|
||||
Add middleware before the MCP routes. The handler itself has no auth layer.
|
||||
### Streamable HTTP Transport
|
||||
|
||||
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
|
||||
|
||||
#### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
|
||||
```
|
||||
|
||||
Mounts the handler at `{BasePath}` (all methods).
|
||||
|
||||
#### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
|
||||
```
|
||||
|
||||
Registers GET, POST, DELETE on `{BasePath}`.
|
||||
|
||||
#### Gin / net/http / Echo
|
||||
|
||||
```go
|
||||
h := handler.StreamableHTTPServer()
|
||||
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||
|
||||
engine.Any("/mcp", gin.WrapH(h)) // Gin
|
||||
http.Handle("/mcp", h) // net/http
|
||||
e.Any("/mcp", echo.WrapHandler(h)) // Echo
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## OAuth2 Authentication
|
||||
|
||||
`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
|
||||
|
||||
It can operate as:
|
||||
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
|
||||
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
|
||||
- **Both simultaneously**
|
||||
|
||||
### Standard endpoints served
|
||||
|
||||
| Path | Spec | Purpose |
|
||||
|---|---|---|
|
||||
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
|
||||
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
|
||||
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
|
||||
| `POST /oauth/authorize` | — | Login form submission |
|
||||
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
|
||||
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
|
||||
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
|
||||
|
||||
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
|
||||
|
||||
---
|
||||
|
||||
### Mode 1 — Direct login (server as identity provider)
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
db, _ := sql.Open("postgres", dsn)
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
|
||||
BaseURL: "https://api.example.com",
|
||||
BasePath: "/mcp",
|
||||
})
|
||||
|
||||
// Enable the OAuth2 server — auth enables the login form
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
security.RegisterSecurityHooks(handler, securityList)
|
||||
|
||||
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
```
|
||||
|
||||
MCP client flow:
|
||||
1. Discovers server at `/.well-known/oauth-authorization-server`
|
||||
2. Registers itself at `/oauth/register`
|
||||
3. Redirects user to `/oauth/authorize` → login form appears
|
||||
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
|
||||
5. Uses token on all MCP tool calls
|
||||
|
||||
---
|
||||
|
||||
### Mode 2 — External provider (Google, GitHub, etc.)
|
||||
|
||||
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
RedirectURL: "https://api.example.com/oauth/provider/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
})
|
||||
|
||||
// nil = no password login; Google handles auth
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, nil)
|
||||
handler.RegisterOAuth2Provider(auth, "google")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Mode 3 — Both (login form + external providers)
|
||||
|
||||
```go
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
LoginTitle: "My App Login",
|
||||
}, auth) // auth enables the username/password form
|
||||
|
||||
handler.RegisterOAuth2Provider(googleAuth, "google")
|
||||
handler.RegisterOAuth2Provider(githubAuth, "github")
|
||||
```
|
||||
|
||||
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
|
||||
|
||||
---
|
||||
|
||||
### Using `security.OAuthServer` standalone
|
||||
|
||||
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
|
||||
|
||||
```go
|
||||
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
oauthSrv.RegisterExternalProvider(googleAuth, "google")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
|
||||
mux.Handle("/mcp/", myMCPHandler)
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Cookie-based flow (legacy)
|
||||
|
||||
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
ProviderName: "google",
|
||||
LoginPath: "/auth/google/login",
|
||||
CallbackPath: "/auth/google/callback",
|
||||
AfterLoginRedirect: "/",
|
||||
})
|
||||
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
|
||||
|
||||
### Wiring security hooks
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
securityList := security.NewSecurityList(mySecurityProvider)
|
||||
resolvemcp.RegisterSecurityHooks(handler, securityList)
|
||||
```
|
||||
|
||||
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
|
||||
|
||||
| Hook | Effect |
|
||||
|---|---|
|
||||
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
|
||||
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
|
||||
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
|
||||
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
|
||||
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
|
||||
|
||||
### Per-entity operation rules
|
||||
|
||||
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
// Read-only entity
|
||||
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: false,
|
||||
CanUpdate: false,
|
||||
CanDelete: false,
|
||||
})
|
||||
|
||||
// Public read, authenticated write
|
||||
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
|
||||
CanPublicRead: true,
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
To update rules for an already-registered model:
|
||||
|
||||
```go
|
||||
handler.SetModelRules("public", "users", modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
|
||||
|
||||
### ModelRules fields
|
||||
|
||||
| Field | Default | Description |
|
||||
|---|---|---|
|
||||
| `CanPublicRead` | `false` | Allow unauthenticated reads |
|
||||
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
|
||||
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
|
||||
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
|
||||
| `CanRead` | `true` | Allow authenticated reads |
|
||||
| `CanCreate` | `true` | Allow authenticated creates |
|
||||
| `CanUpdate` | `true` | Allow authenticated updates |
|
||||
| `CanDelete` | `true` | Allow authenticated deletes |
|
||||
| `SecurityDisabled` | `false` | Skip all security checks for this model |
|
||||
|
||||
---
|
||||
|
||||
@@ -204,6 +438,35 @@ Delete a record by primary key. **Irreversible.**
|
||||
{ "success": true, "data": { ...deleted record... } }
|
||||
```
|
||||
|
||||
### Annotation Tool — `resolvespec_annotate`
|
||||
|
||||
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
|
||||
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
|
||||
|
||||
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "set" }
|
||||
```
|
||||
|
||||
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users" }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resource — `{schema}.{entity}`
|
||||
|
||||
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||
|
||||
107
pkg/resolvemcp/annotation.go
Normal file
107
pkg/resolvemcp/annotation.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const annotationToolName = "resolvespec_annotate"
|
||||
|
||||
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
|
||||
// The tool lets models/entities store and retrieve freeform annotation records
|
||||
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
|
||||
func registerAnnotationTool(h *Handler) {
|
||||
tool := mcp.NewTool(annotationToolName,
|
||||
mcp.WithDescription(
|
||||
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
|
||||
"To set annotations: provide both 'tool_name' and 'annotations'. "+
|
||||
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
|
||||
"To get annotations: provide only 'tool_name'. "+
|
||||
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
|
||||
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
|
||||
"a model/entity name (e.g. 'public.users'), or any other key.",
|
||||
),
|
||||
mcp.WithString("tool_name",
|
||||
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
|
||||
mcp.Required(),
|
||||
),
|
||||
mcp.WithObject("annotations",
|
||||
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
|
||||
toolName, ok := args["tool_name"].(string)
|
||||
if !ok || toolName == "" {
|
||||
return mcp.NewToolResultError("missing required argument: tool_name"), nil
|
||||
}
|
||||
|
||||
annotations, hasAnnotations := args["annotations"]
|
||||
|
||||
if hasAnnotations && annotations != nil {
|
||||
return executeSetAnnotation(ctx, h, toolName, annotations)
|
||||
}
|
||||
return executeGetAnnotation(ctx, h, toolName)
|
||||
})
|
||||
}
|
||||
|
||||
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
|
||||
jsonBytes, err := json.Marshal(annotations)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
|
||||
}
|
||||
|
||||
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "set",
|
||||
})
|
||||
}
|
||||
|
||||
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
|
||||
var rows []map[string]interface{}
|
||||
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
var annotations interface{}
|
||||
if len(rows) > 0 {
|
||||
// The procedure returns a single value; extract the first column of the first row.
|
||||
for _, v := range rows[0] {
|
||||
annotations = v
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
|
||||
switch v := annotations.(type) {
|
||||
case []byte:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal(v, &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
case string:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal([]byte(v), &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "get",
|
||||
"annotations": annotations,
|
||||
})
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func getCursorFilter(
|
||||
reverse := direction < 0
|
||||
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -14,23 +14,27 @@ import (
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// Handler exposes registered database models as MCP tools and resources.
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
oauth2Regs []oauth2Registration
|
||||
oauthSrv *security.OAuthServer
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||
return &Handler{
|
||||
h := &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
@@ -39,6 +43,8 @@ func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *
|
||||
name: "resolvemcp",
|
||||
version: "1.0.0",
|
||||
}
|
||||
registerAnnotationTool(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry.
|
||||
@@ -66,12 +72,20 @@ func (h *Handler) SSEServer() http.Handler {
|
||||
return &dynamicSSEHandler{h: h}
|
||||
}
|
||||
|
||||
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
|
||||
// client-server communication (POST for requests, GET for server-initiated messages).
|
||||
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
|
||||
func (h *Handler) StreamableHTTPServer() http.Handler {
|
||||
return server.NewStreamableHTTPServer(h.mcpServer)
|
||||
}
|
||||
|
||||
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
|
||||
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
h.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath(basePath),
|
||||
server.WithStaticBasePath(basePath),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -123,6 +137,32 @@ func (h *Handler) RegisterModel(schema, entity string, model interface{}) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model and sets per-entity operation rules
|
||||
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetModelRules updates the operation rules for an already-registered model.
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
return reg.SetModelRules(buildModelName(schema, entity), rules)
|
||||
}
|
||||
|
||||
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||
func buildModelName(schema, entity string) string {
|
||||
if schema == "" {
|
||||
@@ -160,8 +200,19 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
|
||||
return defaultSchema, entity
|
||||
}
|
||||
|
||||
// recoverPanic catches a panic from the current goroutine and returns it as an error.
|
||||
// Usage: defer recoverPanic(&returnedErr)
|
||||
func recoverPanic(err *error) {
|
||||
if r := recover(); r != nil {
|
||||
msg := fmt.Sprintf("%v", r)
|
||||
logger.Error("[resolvemcp] panic recovered: %s", msg)
|
||||
*err = fmt.Errorf("internal error: %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// executeRead reads records from the database and returns raw data + metadata.
|
||||
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
|
||||
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||
@@ -217,15 +268,6 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
|
||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||
}
|
||||
|
||||
// Preloads
|
||||
if len(options.Preload) > 0 {
|
||||
var err error
|
||||
query, err = h.applyPreloads(model, query, options.Preload)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Filters
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
@@ -267,7 +309,7 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
|
||||
}
|
||||
}
|
||||
|
||||
// Count
|
||||
// Count — must happen before preloads are applied; Bun panics when counting with relations.
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||
@@ -281,6 +323,15 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
|
||||
// Preloads — applied after count to avoid Bun panic when counting with relations.
|
||||
if len(options.Preload) > 0 {
|
||||
var preloadErr error
|
||||
query, preloadErr = h.applyPreloads(model, query, options.Preload)
|
||||
if preloadErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeRead hook
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
@@ -341,7 +392,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
|
||||
}
|
||||
|
||||
// executeCreate inserts one or more records.
|
||||
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
|
||||
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
@@ -425,7 +477,8 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
}
|
||||
|
||||
// executeUpdate updates a record by ID.
|
||||
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
|
||||
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
@@ -535,7 +588,8 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
|
||||
}
|
||||
|
||||
// executeDelete deletes a record by ID.
|
||||
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
|
||||
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
264
pkg/resolvemcp/oauth2.go
Normal file
264
pkg/resolvemcp/oauth2.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 registration on the Handler
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// oauth2Registration stores a configured auth provider and its route config.
|
||||
type oauth2Registration struct {
|
||||
auth *security.DatabaseAuthenticator
|
||||
cfg OAuth2RouteConfig
|
||||
}
|
||||
|
||||
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
|
||||
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
|
||||
// Call this once per provider before serving requests.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google",
|
||||
// LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback",
|
||||
// AfterLoginRedirect: "/",
|
||||
// })
|
||||
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
|
||||
}
|
||||
|
||||
// HTTPHandler returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP SSE transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(...)
|
||||
// handler.RegisterOAuth2(auth, cfg)
|
||||
// handler.EnableOAuthServer(resolvemcp.OAuthServerConfig{Issuer: "https://api.example.com"})
|
||||
// security.RegisterSecurityHooks(handler, securityList)
|
||||
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedSSEServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/sse", mcpHandler)
|
||||
mux.Handle(basePath+"/message", mcpHandler)
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// StreamableHTTPMux returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP streamable HTTP transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
|
||||
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
mux.Handle(basePath, mcpHandler)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
|
||||
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
|
||||
for _, reg := range h.oauth2Regs {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if reg.cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
|
||||
}
|
||||
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
|
||||
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Auth-wrapped transports
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
|
||||
// The middleware reads the session cookie / Authorization header and populates the user
|
||||
// context into the request context, making it available to BeforeHandle security hooks.
|
||||
// Unauthenticated requests receive 401 before reaching any MCP tool.
|
||||
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
|
||||
// Unauthenticated requests continue as guest rather than returning 401.
|
||||
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
|
||||
// to allow mixed public/private access.
|
||||
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
|
||||
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
|
||||
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 route config and standalone handlers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
|
||||
type OAuth2RouteConfig struct {
|
||||
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
|
||||
// (e.g. "google", "github", "microsoft").
|
||||
ProviderName string
|
||||
|
||||
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
|
||||
// (e.g. "/auth/google/login").
|
||||
LoginPath string
|
||||
|
||||
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
|
||||
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
|
||||
CallbackPath string
|
||||
|
||||
// AfterLoginRedirect is the URL to redirect the browser to after a successful
|
||||
// login. When empty the LoginResponse JSON is written directly to the response.
|
||||
AfterLoginRedirect string
|
||||
|
||||
// CookieOptions customises the session cookie written on successful login.
|
||||
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
|
||||
CookieOptions *security.SessionCookieOptions
|
||||
}
|
||||
|
||||
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
|
||||
// the OAuth2 provider's authorization URL.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
|
||||
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := auth.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to generate state", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
|
||||
// callback: exchanges the authorization code for a session token, writes the session
|
||||
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
|
||||
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
security.SetSessionCookie(w, loginResp, cookieOpts...)
|
||||
|
||||
if afterLoginRedirect != "" {
|
||||
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Gorilla Mux convenience helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google", LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
|
||||
// })
|
||||
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
|
||||
}
|
||||
|
||||
muxRouter.Handle(cfg.LoginPath,
|
||||
OAuth2LoginHandler(auth, cfg.ProviderName),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
muxRouter.Handle(cfg.CallbackPath,
|
||||
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
|
||||
).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
|
||||
// with required authentication middleware applied.
|
||||
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedSSEServer(securityList)
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
|
||||
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
|
||||
// Gorilla Mux router with required authentication middleware applied.
|
||||
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedStreamableHTTPServer(securityList)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
51
pkg/resolvemcp/oauth2_server.go
Normal file
51
pkg/resolvemcp/oauth2_server.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
|
||||
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
|
||||
// only external providers registered via RegisterOAuth2Provider.
|
||||
//
|
||||
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
|
||||
// endpoints required by MCP clients alongside the MCP transport:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
|
||||
// POST /oauth/authorize Login form submission (password flow)
|
||||
// POST /oauth/token Bearer token exchange + refresh
|
||||
// GET /oauth/provider/callback External provider redirect target
|
||||
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
|
||||
h.oauthSrv = security.NewOAuthServer(cfg, auth)
|
||||
// Wire any external providers already registered via RegisterOAuth2
|
||||
for _, reg := range h.oauth2Regs {
|
||||
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
|
||||
// EnableOAuthServer must be called before this. The auth must have been configured with
|
||||
// WithOAuth2(providerName, ...) for the given provider name.
|
||||
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
|
||||
if h.oauthSrv != nil {
|
||||
h.oauthSrv.RegisterExternalProvider(auth, providerName)
|
||||
}
|
||||
}
|
||||
|
||||
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
|
||||
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
|
||||
oauthHandler := h.oauthSrv.HTTPHandler()
|
||||
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
|
||||
mux.Handle("/.well-known/", oauthHandler)
|
||||
mux.Handle("/oauth/", oauthHandler)
|
||||
if h.oauthSrv != nil {
|
||||
// Also mount the external provider callback path if it differs from /oauth/
|
||||
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
|
||||
}
|
||||
}
|
||||
@@ -98,3 +98,36 @@ func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
func NewSSEServer(handler *Handler) http.Handler {
|
||||
return handler.SSEServer()
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
|
||||
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
|
||||
// POST for client→server messages, GET for server→client streaming.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
|
||||
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.StreamableHTTPServer()
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
|
||||
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
|
||||
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.StreamableHTTPServer()
|
||||
router.GET(basePath, bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath, bunrouter.HTTPHandler(h))
|
||||
router.DELETE(basePath, bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||
// Mount it at the desired path; that path becomes the MCP endpoint.
|
||||
//
|
||||
// h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||
// http.Handle("/mcp", h)
|
||||
// engine.Any("/mcp", gin.WrapH(h))
|
||||
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
|
||||
return handler.StreamableHTTPServer()
|
||||
}
|
||||
|
||||
115
pkg/resolvemcp/security_hooks.go
Normal file
115
pkg/resolvemcp/security_hooks.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks wires the security package's access-control layer into the
|
||||
// resolvemcp handler. Call it once after creating the handler, before registering models.
|
||||
//
|
||||
// The following controls are applied:
|
||||
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
|
||||
// stored via RegisterModelWithRules / SetModelRules.
|
||||
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
|
||||
// - Column-level security: sensitive columns masked/hidden in read results.
|
||||
// - Audit logging after each read.
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// BeforeHandle: enforce model-level operation rules (auth check).
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
|
||||
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (2nd): audit log.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.LogDataAccess(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeUpdate: enforce CanUpdate rule.
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeDelete: enforce CanDelete rule.
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvemcp handler")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
return s.ctx.Query
|
||||
}
|
||||
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if q, ok := query.(common.SelectQuery); ok {
|
||||
s.ctx.Query = q
|
||||
}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||
// Add to blacklist
|
||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
||||
"token": req.Token,
|
||||
"user_id": req.UserID,
|
||||
}).Error
|
||||
// Invalidate session via stored procedure
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||
|
||||
@@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
- ✅ **Testable** - Easy to mock and test
|
||||
- ✅ **Extensible** - Implement custom providers for your needs
|
||||
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
||||
- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
|
||||
|
||||
## Stored Procedure Architecture
|
||||
|
||||
@@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
|
||||
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
|
||||
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
|
||||
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
|
||||
|
||||
See `database_schema.sql` for complete stored procedure definitions and examples.
|
||||
|
||||
@@ -897,6 +904,155 @@ securityList := security.NewSecurityList(provider)
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||
```
|
||||
|
||||
## OAuth2 Authorization Server
|
||||
|
||||
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Method | Path | RFC |
|
||||
|--------|------|-----|
|
||||
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
|
||||
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
|
||||
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
|
||||
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
|
||||
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
|
||||
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
|
||||
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
|
||||
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
|
||||
|
||||
### Config
|
||||
|
||||
```go
|
||||
cfg := security.OAuthServerConfig{
|
||||
Issuer: "https://example.com", // Required — token issuer URL
|
||||
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
|
||||
LoginTitle: "My App Login", // HTML login page title
|
||||
PersistClients: true, // Store clients in DB (multi-instance safe)
|
||||
PersistCodes: true, // Store codes in DB (multi-instance safe)
|
||||
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
|
||||
AccessTokenTTL: time.Hour,
|
||||
AuthCodeTTL: 5 * time.Minute,
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Default | Notes |
|
||||
|-------|---------|-------|
|
||||
| `Issuer` | — | Required |
|
||||
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
|
||||
| `LoginTitle` | `"Login"` | |
|
||||
| `PersistClients` | `false` | Set `true` for multi-instance |
|
||||
| `PersistCodes` | `false` | Set `true` for multi-instance |
|
||||
| `DefaultScopes` | `nil` | |
|
||||
| `AccessTokenTTL` | `1h` | |
|
||||
| `AuthCodeTTL` | `5m` | |
|
||||
|
||||
### Operating Modes
|
||||
|
||||
**Mode 1 — Direct login (username/password form)**
|
||||
|
||||
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
```
|
||||
|
||||
**Mode 2 — External provider federation**
|
||||
|
||||
Pass `nil` as auth and register external providers. The authorize page shows a provider selection UI.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, nil)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
srv.RegisterExternalProvider(githubAuth, "github")
|
||||
```
|
||||
|
||||
**Mode 3 — Both**
|
||||
|
||||
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
```
|
||||
|
||||
### Standalone Usage
|
||||
|
||||
```go
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/.well-known/", srv.HTTPHandler())
|
||||
mux.Handle("/oauth/", srv.HTTPHandler())
|
||||
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
|
||||
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
### DB Persistence
|
||||
|
||||
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
|
||||
|
||||
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
|
||||
|
||||
#### New DB Types
|
||||
|
||||
```go
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
#### DatabaseAuthenticator OAuth Methods
|
||||
|
||||
```go
|
||||
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
|
||||
auth.OAuthGetClient(ctx, clientID) // retrieve client
|
||||
auth.OAuthSaveCode(ctx, code) // persist authorization code
|
||||
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
|
||||
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
|
||||
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
|
||||
```
|
||||
|
||||
#### SQLNames Fields
|
||||
|
||||
```go
|
||||
type SQLNames struct {
|
||||
// ... existing fields ...
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
```
|
||||
|
||||
The main changes:
|
||||
1. Security package no longer knows about specific spec types
|
||||
2. Each spec registers its own security hooks
|
||||
|
||||
@@ -1397,3 +1397,173 @@ $$ LANGUAGE plpgsql;
|
||||
|
||||
-- Get credentials by username
|
||||
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Tables (OAuthServer persistence)
|
||||
-- ============================================
|
||||
|
||||
-- oauth_clients: persistent RFC 7591 registered clients
|
||||
CREATE TABLE IF NOT EXISTS oauth_clients (
|
||||
id SERIAL PRIMARY KEY,
|
||||
client_id VARCHAR(255) NOT NULL UNIQUE,
|
||||
redirect_uris TEXT[] NOT NULL,
|
||||
client_name VARCHAR(255),
|
||||
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
|
||||
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
|
||||
CREATE TABLE IF NOT EXISTS oauth_codes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
code VARCHAR(255) NOT NULL UNIQUE,
|
||||
client_id VARCHAR(255) NOT NULL REFERENCES oauth_clients(client_id) ON DELETE CASCADE,
|
||||
redirect_uri TEXT NOT NULL,
|
||||
client_state TEXT,
|
||||
code_challenge VARCHAR(255) NOT NULL,
|
||||
code_challenge_method VARCHAR(10) DEFAULT 'S256',
|
||||
session_token TEXT NOT NULL,
|
||||
scopes TEXT[],
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Stored Procedures
|
||||
-- ============================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_client_id text;
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
v_client_id := p_data->>'client_id';
|
||||
|
||||
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
|
||||
VALUES (
|
||||
v_client_id,
|
||||
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
|
||||
p_data->>'client_name',
|
||||
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
|
||||
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
|
||||
)
|
||||
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
|
||||
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
SELECT to_jsonb(oauth_clients.*)
|
||||
INTO v_row
|
||||
FROM oauth_clients
|
||||
WHERE client_id = p_client_id AND is_active = true;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
|
||||
RETURNS TABLE(p_success bool, p_error text)
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, scopes, expires_at)
|
||||
VALUES (
|
||||
p_data->>'code',
|
||||
p_data->>'client_id',
|
||||
p_data->>'redirect_uri',
|
||||
p_data->>'client_state',
|
||||
p_data->>'code_challenge',
|
||||
COALESCE(p_data->>'code_challenge_method', 'S256'),
|
||||
p_data->>'session_token',
|
||||
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
|
||||
(p_data->>'expires_at')::timestamp
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, null::text;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
DELETE FROM oauth_codes
|
||||
WHERE code = p_code AND expires_at > now()
|
||||
RETURNING jsonb_build_object(
|
||||
'client_id', client_id,
|
||||
'redirect_uri', redirect_uri,
|
||||
'client_state', client_state,
|
||||
'code_challenge', code_challenge,
|
||||
'code_challenge_method', code_challenge_method,
|
||||
'session_token', session_token,
|
||||
'scopes', to_jsonb(scopes)
|
||||
) INTO v_row;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
SELECT jsonb_build_object(
|
||||
'active', true,
|
||||
'sub', u.id::text,
|
||||
'username', u.username,
|
||||
'email', u.email,
|
||||
'user_level', u.user_level,
|
||||
'roles', to_jsonb(string_to_array(COALESCE(u.roles, ''), ',')),
|
||||
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
|
||||
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
|
||||
)
|
||||
INTO v_row
|
||||
FROM user_sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.session_token = p_token
|
||||
AND s.expires_at > now()
|
||||
AND u.is_active = true;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
|
||||
RETURNS TABLE(p_success bool, p_error text)
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
DELETE FROM user_sessions WHERE session_token = p_token;
|
||||
RETURN QUERY SELECT true, null::text;
|
||||
END;
|
||||
$$;
|
||||
|
||||
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// For JWT, logout could involve token blacklisting
|
||||
// Add token to blacklist table
|
||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
||||
// "token": req.Token,
|
||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
||||
// }).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
||||
`, userJSON).Scan(&success, &errMsg, &userID)
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_createsession($1::jsonb)
|
||||
`, sessionJSON).Scan(&success, &errMsg)
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getrefreshtoken($1)
|
||||
`, refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
|
||||
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update session: %w", err)
|
||||
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getuser($1)
|
||||
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
|
||||
859
pkg/security/oauth_server.go
Normal file
859
pkg/security/oauth_server.go
Normal file
@@ -0,0 +1,859 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
|
||||
type OAuthServerConfig struct {
|
||||
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
|
||||
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
|
||||
Issuer string
|
||||
|
||||
// ProviderCallbackPath is the path on this server that external OAuth2 providers
|
||||
// redirect back to. Defaults to "/oauth/provider/callback".
|
||||
ProviderCallbackPath string
|
||||
|
||||
// LoginTitle is shown on the built-in login form when the server acts as its own
|
||||
// identity provider. Defaults to "MCP Login".
|
||||
LoginTitle string
|
||||
|
||||
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
|
||||
// Clients registered during a session survive server restarts.
|
||||
PersistClients bool
|
||||
|
||||
// PersistCodes stores authorization codes in the database.
|
||||
// Useful for multi-instance deployments. Defaults to in-memory.
|
||||
PersistCodes bool
|
||||
|
||||
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
|
||||
DefaultScopes []string
|
||||
|
||||
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
|
||||
AccessTokenTTL time.Duration
|
||||
|
||||
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
|
||||
AuthCodeTTL time.Duration
|
||||
}
|
||||
|
||||
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
|
||||
type oauthClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// pendingAuth tracks an in-progress authorization code exchange.
|
||||
type pendingAuth struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
ClientState string
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
ProviderName string // empty = password login
|
||||
ExpiresAt time.Time
|
||||
SessionToken string // set after authentication completes
|
||||
Scopes []string // requested scopes
|
||||
}
|
||||
|
||||
// externalProvider pairs a DatabaseAuthenticator with its provider name.
|
||||
type externalProvider struct {
|
||||
auth *DatabaseAuthenticator
|
||||
providerName string
|
||||
}
|
||||
|
||||
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
|
||||
//
|
||||
// It can act as both:
|
||||
// - A direct identity provider using DatabaseAuthenticator username/password login
|
||||
// - A federation layer that delegates authentication to external OAuth2 providers
|
||||
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
|
||||
//
|
||||
// The server exposes these RFC-compliant endpoints:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
|
||||
// POST /oauth/authorize Direct login form submission
|
||||
// POST /oauth/token Token exchange and refresh
|
||||
// POST /oauth/revoke RFC 7009 — token revocation
|
||||
// POST /oauth/introspect RFC 7662 — token introspection
|
||||
// GET {ProviderCallbackPath} Internal — external provider callback
|
||||
type OAuthServer struct {
|
||||
cfg OAuthServerConfig
|
||||
auth *DatabaseAuthenticator // nil = only external providers
|
||||
providers []externalProvider
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[string]*oauthClient
|
||||
pending map[string]*pendingAuth // provider_state → pending (external flow)
|
||||
codes map[string]*pendingAuth // auth_code → pending (post-auth)
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new MCP OAuth2 authorization server.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
|
||||
// acts as its own identity provider). Pass nil to use only external providers.
|
||||
// External providers are added separately via RegisterExternalProvider.
|
||||
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
|
||||
if cfg.ProviderCallbackPath == "" {
|
||||
cfg.ProviderCallbackPath = "/oauth/provider/callback"
|
||||
}
|
||||
if cfg.LoginTitle == "" {
|
||||
cfg.LoginTitle = "Sign in"
|
||||
}
|
||||
if len(cfg.DefaultScopes) == 0 {
|
||||
cfg.DefaultScopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
if cfg.AccessTokenTTL == 0 {
|
||||
cfg.AccessTokenTTL = 24 * time.Hour
|
||||
}
|
||||
if cfg.AuthCodeTTL == 0 {
|
||||
cfg.AuthCodeTTL = 2 * time.Minute
|
||||
}
|
||||
s := &OAuthServer{
|
||||
cfg: cfg,
|
||||
auth: auth,
|
||||
clients: make(map[string]*oauthClient),
|
||||
pending: make(map[string]*pendingAuth),
|
||||
codes: make(map[string]*pendingAuth),
|
||||
}
|
||||
go s.cleanupExpired()
|
||||
return s
|
||||
}
|
||||
|
||||
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
|
||||
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
|
||||
// configured with WithOAuth2(providerName, ...) before calling this.
|
||||
// Multiple providers can be registered; the first is used as the default.
|
||||
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
|
||||
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
|
||||
}
|
||||
|
||||
// ProviderCallbackPath returns the configured path for external provider callbacks.
|
||||
func (s *OAuthServer) ProviderCallbackPath() string {
|
||||
return s.cfg.ProviderCallbackPath
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
|
||||
// Mount it at the root of your HTTP server alongside the MCP transport.
|
||||
//
|
||||
// mux := http.NewServeMux()
|
||||
// mux.Handle("/", oauthServer.HTTPHandler())
|
||||
// mux.Handle("/mcp/", mcpTransport)
|
||||
func (s *OAuthServer) HTTPHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
|
||||
mux.HandleFunc("/oauth/register", s.registerHandler)
|
||||
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
|
||||
mux.HandleFunc("/oauth/token", s.tokenHandler)
|
||||
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
|
||||
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
|
||||
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
|
||||
return mux
|
||||
}
|
||||
|
||||
// cleanupExpired removes stale pending auths and codes every 5 minutes.
|
||||
func (s *OAuthServer) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, p := range s.pending {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.pending, k)
|
||||
}
|
||||
}
|
||||
for k, p := range s.codes {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.codes, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 8414 — Server metadata
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
issuer := s.cfg.Issuer
|
||||
meta := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": issuer + "/oauth/authorize",
|
||||
"token_endpoint": issuer + "/oauth/token",
|
||||
"registration_endpoint": issuer + "/oauth/register",
|
||||
"revocation_endpoint": issuer + "/oauth/revoke",
|
||||
"introspection_endpoint": issuer + "/oauth/introspect",
|
||||
"scopes_supported": s.cfg.DefaultScopes,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"token_endpoint_auth_methods_supported": []string{"none"},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(meta) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7591 — Dynamic client registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(req.RedirectURIs) == 0 {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
grantTypes := req.GrantTypes
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = []string{"authorization_code"}
|
||||
}
|
||||
allowedScopes := req.AllowedScopes
|
||||
if len(allowedScopes) == 0 {
|
||||
allowedScopes = s.cfg.DefaultScopes
|
||||
}
|
||||
clientID, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
client := &oauthClient{
|
||||
ClientID: clientID,
|
||||
RedirectURIs: req.RedirectURIs,
|
||||
ClientName: req.ClientName,
|
||||
GrantTypes: grantTypes,
|
||||
AllowedScopes: allowedScopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistClients && s.auth != nil {
|
||||
dbClient := &OAuthServerClient{
|
||||
ClientID: client.ClientID,
|
||||
RedirectURIs: client.RedirectURIs,
|
||||
ClientName: client.ClientName,
|
||||
GrantTypes: client.GrantTypes,
|
||||
AllowedScopes: client.AllowedScopes,
|
||||
}
|
||||
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(client) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Authorization endpoint — GET + POST /oauth/authorize
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
s.authorizeGet(w, r)
|
||||
case http.MethodPost:
|
||||
s.authorizePost(w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeGet validates the request and either:
|
||||
// - Redirects to an external provider (if providers are registered)
|
||||
// - Renders a login form (if the server is its own identity provider)
|
||||
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
clientState := q.Get("state")
|
||||
codeChallenge := q.Get("code_challenge")
|
||||
codeChallengeMethod := q.Get("code_challenge_method")
|
||||
providerName := q.Get("provider")
|
||||
scopeStr := q.Get("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
if q.Get("response_type") != "code" {
|
||||
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallenge == "" {
|
||||
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
|
||||
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok {
|
||||
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// External provider path
|
||||
if len(s.providers) > 0 {
|
||||
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
|
||||
return
|
||||
}
|
||||
|
||||
// Direct login form path (server is its own identity provider)
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
|
||||
}
|
||||
|
||||
// authorizePost handles login form submission for the direct login flow.
|
||||
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
clientID := r.FormValue("client_id")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientState := r.FormValue("client_state")
|
||||
codeChallenge := r.FormValue("code_challenge")
|
||||
codeChallengeMethod := r.FormValue("code_challenge_method")
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
scopeStr := r.FormValue("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
|
||||
}
|
||||
|
||||
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
|
||||
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
var provider *externalProvider
|
||||
if providerName != "" {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == providerName {
|
||||
provider = &s.providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = &s.providers[0]
|
||||
}
|
||||
|
||||
providerState, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: provider.providerName,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Scopes: scopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.pending[providerState] = pending
|
||||
s.mu.Unlock()
|
||||
|
||||
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// External provider callback — GET {ProviderCallbackPath}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
providerState := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
pending, ok := s.pending[providerState]
|
||||
if ok {
|
||||
delete(s.pending, providerState)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider := s.providerByName(pending.ProviderName)
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token,
|
||||
pending.ClientID, pending.RedirectURI, pending.ClientState,
|
||||
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
|
||||
}
|
||||
|
||||
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
|
||||
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
authCode, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: providerName,
|
||||
SessionToken: sessionToken,
|
||||
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode := &OAuthCode{
|
||||
Code: authCode,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
SessionToken: sessionToken,
|
||||
Scopes: scopes,
|
||||
ExpiresAt: pending.ExpiresAt,
|
||||
}
|
||||
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
s.codes[authCode] = pending
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
qp := redirectURL.Query()
|
||||
qp.Set("code", authCode)
|
||||
if clientState != "" {
|
||||
qp.Set("state", clientState)
|
||||
}
|
||||
redirectURL.RawQuery = qp.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Token endpoint — POST /oauth/token
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch r.FormValue("grant_type") {
|
||||
case "authorization_code":
|
||||
s.handleAuthCodeGrant(w, r)
|
||||
case "refresh_token":
|
||||
s.handleRefreshGrant(w, r)
|
||||
default:
|
||||
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.FormValue("code")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientID := r.FormValue("client_id")
|
||||
codeVerifier := r.FormValue("code_verifier")
|
||||
|
||||
if code == "" || codeVerifier == "" {
|
||||
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
var scopes []string
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = oauthCode.SessionToken
|
||||
scopes = oauthCode.Scopes
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
pending, ok := s.codes[code]
|
||||
if ok {
|
||||
delete(s.codes, code)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = pending.SessionToken
|
||||
scopes = pending.Scopes
|
||||
}
|
||||
|
||||
writeOAuthToken(w, sessionToken, "", scopes)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
|
||||
refreshToken := r.FormValue("refresh_token")
|
||||
providerName := r.FormValue("provider")
|
||||
if refreshToken == "" {
|
||||
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Try external providers first, then fall back to DatabaseAuthenticator
|
||||
provider := s.providerByName(providerName)
|
||||
if provider != nil {
|
||||
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7009 — Token revocation
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
if token == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7662 — Token introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if token == "" || s.auth == nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.auth.OAuthIntrospectToken(r.Context(), token)
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(info) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Login form (direct identity provider mode)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
errHTML := ""
|
||||
if errMsg != "" {
|
||||
errHTML = `<p style="color:red">` + errMsg + `</p>`
|
||||
}
|
||||
fmt.Fprintf(w, loginFormHTML,
|
||||
s.cfg.LoginTitle,
|
||||
s.cfg.LoginTitle,
|
||||
errHTML,
|
||||
clientID,
|
||||
htmlEscape(redirectURI),
|
||||
htmlEscape(clientState),
|
||||
htmlEscape(codeChallenge),
|
||||
htmlEscape(codeChallengeMethod),
|
||||
htmlEscape(scope),
|
||||
)
|
||||
}
|
||||
|
||||
const loginFormHTML = `<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>%s</title>
|
||||
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
|
||||
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
|
||||
h2{margin:0 0 1.5rem;font-size:1.25rem}
|
||||
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
|
||||
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
|
||||
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
|
||||
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
|
||||
</head><body><div class="card">
|
||||
<h2>%s</h2>%s
|
||||
<form method="POST" action="/oauth/authorize">
|
||||
<input type="hidden" name="client_id" value="%s">
|
||||
<input type="hidden" name="redirect_uri" value="%s">
|
||||
<input type="hidden" name="client_state" value="%s">
|
||||
<input type="hidden" name="code_challenge" value="%s">
|
||||
<input type="hidden" name="code_challenge_method" value="%s">
|
||||
<input type="hidden" name="scope" value="%s">
|
||||
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
|
||||
<label>Password</label><input type="password" name="password" autocomplete="current-password">
|
||||
<button type="submit">Sign in</button>
|
||||
</form></div></body></html>`
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
|
||||
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
|
||||
s.mu.RLock()
|
||||
c, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
if !s.cfg.PersistClients || s.auth == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c = &oauthClient{
|
||||
ClientID: dbClient.ClientID,
|
||||
RedirectURIs: dbClient.RedirectURIs,
|
||||
ClientName: dbClient.ClientName,
|
||||
GrantTypes: dbClient.GrantTypes,
|
||||
AllowedScopes: dbClient.AllowedScopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = c
|
||||
s.mu.Unlock()
|
||||
return c, true
|
||||
}
|
||||
|
||||
func (s *OAuthServer) providerByName(name string) *externalProvider {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == name {
|
||||
return &s.providers[i]
|
||||
}
|
||||
}
|
||||
// If name is empty and only one provider exists, return it
|
||||
if name == "" && len(s.providers) == 1 {
|
||||
return &s.providers[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePKCESHA256(challenge, verifier string) bool {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
|
||||
}
|
||||
|
||||
func randomOAuthToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func oauthSliceContains(slice []string, s string) bool {
|
||||
for _, v := range slice {
|
||||
if strings.EqualFold(v, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
if refreshToken != "" {
|
||||
resp["refresh_token"] = refreshToken
|
||||
}
|
||||
if len(scopes) > 0 {
|
||||
resp["scope"] = strings.Join(scopes, " ")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
|
||||
resp := map[string]string{"error": errCode}
|
||||
if description != "" {
|
||||
resp["error_description"] = description
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func htmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
return s
|
||||
}
|
||||
202
pkg/security/oauth_server_db.go
Normal file
202
pkg/security/oauth_server_db.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthCode is a short-lived authorization code.
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OAuthTokenInfo is the RFC 7662 token introspection response.
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthRegisterClient persists an OAuth2 client registration.
|
||||
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
|
||||
input, err := json.Marshal(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal client: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to register client")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registered client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthGetClient retrieves a registered client by ID.
|
||||
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("client not found")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthSaveCode persists an authorization code.
|
||||
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
|
||||
input, err := json.Marshal(code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal code: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to save code")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
|
||||
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired code")
|
||||
}
|
||||
|
||||
var result OAuthCode
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse code data: %w", err)
|
||||
}
|
||||
result.Code = code
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
|
||||
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to introspect token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("introspection failed")
|
||||
}
|
||||
|
||||
var result OAuthTokenInfo
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token info: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
|
||||
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to revoke token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
|
||||
RPOrigin string
|
||||
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
opts.Timeout = 60000 // 60 seconds default
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var userID sql.NullInt64
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var updateError sql.NullString
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var errorMsg sql.NullString
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
|
||||
@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
|
||||
// DatabaseAuthenticator provides session-based authentication with database storage
|
||||
// All database operations go through stored procedures for security and consistency
|
||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||
// resolvespec_session_update, resolvespec_refresh_token
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See database_schema.sql for procedure definitions
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
|
||||
Cache *cache.Cache
|
||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||
PasskeyProvider PasskeyProvider
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
// Partial overrides are supported: only set the fields you want to change.
|
||||
SQLNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
cacheInstance = cache.GetDefaultCache()
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
sqlNames: sqlNames,
|
||||
passkeyProvider: opts.PasskeyProvider,
|
||||
}
|
||||
}
|
||||
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
return nil, fmt.Errorf("failed to marshal login request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_register stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
return fmt.Errorf("failed to marshal logout request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
@@ -266,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
@@ -338,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
return
|
||||
}
|
||||
|
||||
// Call resolvespec_session_update stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||||
// Call api_refresh_token stored procedure
|
||||
// First, we need to get the current user context for the refresh token
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
@@ -368,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
return nil, fmt.Errorf("invalid refresh token")
|
||||
}
|
||||
|
||||
// Call resolvespec_refresh_token to generate new token
|
||||
var newSuccess bool
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
@@ -401,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
|
||||
// JWTAuthenticator provides JWT token-based authentication
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
|
||||
type JWTAuthenticator struct {
|
||||
secretKey []byte
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
|
||||
return &JWTAuthenticator{
|
||||
secretKey: []byte(secretKey),
|
||||
db: db,
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Call resolvespec_jwt_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
@@ -471,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// Call resolvespec_jwt_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
@@ -511,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
|
||||
// DatabaseColumnSecurityProvider loads column security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_column_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseColumnSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db}
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||
var rules []ColumnSecurity
|
||||
|
||||
// Call resolvespec_column_security stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var rulesJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||
@@ -576,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
|
||||
// DatabaseRowSecurityProvider loads row security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_row_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseRowSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db}
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||
var template string
|
||||
var hasBlock bool
|
||||
|
||||
// Call resolvespec_row_security stored procedure
|
||||
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
|
||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
if err != nil {
|
||||
@@ -758,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user data from database
|
||||
var username, email, roles string
|
||||
var userLevel int
|
||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
// Build request JSON for passkey login stored procedure
|
||||
reqData := map[string]any{
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Extract IP and user agent from claims
|
||||
ipAddress := ""
|
||||
userAgent := ""
|
||||
if req.Claims != nil {
|
||||
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||
ipAddress = ip
|
||||
reqData["ip_address"] = ip
|
||||
}
|
||||
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||
userAgent = ua
|
||||
reqData["user_agent"] = ua
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, now())`
|
||||
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
||||
reqJSON, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
|
||||
}
|
||||
|
||||
// Update last login
|
||||
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
||||
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
// Return login response
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
User: &UserContext{
|
||||
UserID: userID,
|
||||
UserName: username,
|
||||
Email: email,
|
||||
UserLevel: userLevel,
|
||||
SessionID: sessionToken,
|
||||
Roles: parseRoles(roles),
|
||||
},
|
||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||
}, nil
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("passkey login failed")
|
||||
}
|
||||
|
||||
var response LoginResponse
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||
|
||||
254
pkg/security/sql_names.go
Normal file
254
pkg/security/sql_names.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// SQLNames defines all configurable SQL stored procedure and table names
|
||||
// used by the security package. Override individual fields to remap
|
||||
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
|
||||
// and MergeSQLNames() to apply partial overrides.
|
||||
type SQLNames struct {
|
||||
// Auth procedures (DatabaseAuthenticator)
|
||||
Login string // default: "resolvespec_login"
|
||||
Register string // default: "resolvespec_register"
|
||||
Logout string // default: "resolvespec_logout"
|
||||
Session string // default: "resolvespec_session"
|
||||
SessionUpdate string // default: "resolvespec_session_update"
|
||||
RefreshToken string // default: "resolvespec_refresh_token"
|
||||
|
||||
// JWT procedures (JWTAuthenticator)
|
||||
JWTLogin string // default: "resolvespec_jwt_login"
|
||||
JWTLogout string // default: "resolvespec_jwt_logout"
|
||||
|
||||
// Security policy procedures
|
||||
ColumnSecurity string // default: "resolvespec_column_security"
|
||||
RowSecurity string // default: "resolvespec_row_security"
|
||||
|
||||
// TOTP procedures (DatabaseTwoFactorProvider)
|
||||
TOTPEnable string // default: "resolvespec_totp_enable"
|
||||
TOTPDisable string // default: "resolvespec_totp_disable"
|
||||
TOTPGetStatus string // default: "resolvespec_totp_get_status"
|
||||
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
|
||||
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
|
||||
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
|
||||
|
||||
// Passkey procedures (DatabasePasskeyProvider)
|
||||
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
|
||||
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
|
||||
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
|
||||
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
|
||||
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
|
||||
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
|
||||
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
|
||||
PasskeyLogin string // default: "resolvespec_passkey_login"
|
||||
|
||||
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
|
||||
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
|
||||
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
|
||||
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
|
||||
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
|
||||
OAuthGetUser string // default: "resolvespec_oauth_getuser"
|
||||
|
||||
// OAuth2 server procedures (OAuthServer persistence)
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
|
||||
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
|
||||
func DefaultSQLNames() *SQLNames {
|
||||
return &SQLNames{
|
||||
Login: "resolvespec_login",
|
||||
Register: "resolvespec_register",
|
||||
Logout: "resolvespec_logout",
|
||||
Session: "resolvespec_session",
|
||||
SessionUpdate: "resolvespec_session_update",
|
||||
RefreshToken: "resolvespec_refresh_token",
|
||||
|
||||
JWTLogin: "resolvespec_jwt_login",
|
||||
JWTLogout: "resolvespec_jwt_logout",
|
||||
|
||||
ColumnSecurity: "resolvespec_column_security",
|
||||
RowSecurity: "resolvespec_row_security",
|
||||
|
||||
TOTPEnable: "resolvespec_totp_enable",
|
||||
TOTPDisable: "resolvespec_totp_disable",
|
||||
TOTPGetStatus: "resolvespec_totp_get_status",
|
||||
TOTPGetSecret: "resolvespec_totp_get_secret",
|
||||
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
|
||||
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
|
||||
|
||||
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
|
||||
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
|
||||
PasskeyGetCredential: "resolvespec_passkey_get_credential",
|
||||
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
|
||||
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
|
||||
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
|
||||
PasskeyUpdateName: "resolvespec_passkey_update_name",
|
||||
PasskeyLogin: "resolvespec_passkey_login",
|
||||
|
||||
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
|
||||
OAuthCreateSession: "resolvespec_oauth_createsession",
|
||||
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
|
||||
OAuthGetUser: "resolvespec_oauth_getuser",
|
||||
|
||||
OAuthRegisterClient: "resolvespec_oauth_register_client",
|
||||
OAuthGetClient: "resolvespec_oauth_get_client",
|
||||
OAuthSaveCode: "resolvespec_oauth_save_code",
|
||||
OAuthExchangeCode: "resolvespec_oauth_exchange_code",
|
||||
OAuthIntrospect: "resolvespec_oauth_introspect",
|
||||
OAuthRevoke: "resolvespec_oauth_revoke",
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||
// If override is nil, a copy of base is returned.
|
||||
func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||
if override == nil {
|
||||
copied := *base
|
||||
return &copied
|
||||
}
|
||||
merged := *base
|
||||
if override.Login != "" {
|
||||
merged.Login = override.Login
|
||||
}
|
||||
if override.Register != "" {
|
||||
merged.Register = override.Register
|
||||
}
|
||||
if override.Logout != "" {
|
||||
merged.Logout = override.Logout
|
||||
}
|
||||
if override.Session != "" {
|
||||
merged.Session = override.Session
|
||||
}
|
||||
if override.SessionUpdate != "" {
|
||||
merged.SessionUpdate = override.SessionUpdate
|
||||
}
|
||||
if override.RefreshToken != "" {
|
||||
merged.RefreshToken = override.RefreshToken
|
||||
}
|
||||
if override.JWTLogin != "" {
|
||||
merged.JWTLogin = override.JWTLogin
|
||||
}
|
||||
if override.JWTLogout != "" {
|
||||
merged.JWTLogout = override.JWTLogout
|
||||
}
|
||||
if override.ColumnSecurity != "" {
|
||||
merged.ColumnSecurity = override.ColumnSecurity
|
||||
}
|
||||
if override.RowSecurity != "" {
|
||||
merged.RowSecurity = override.RowSecurity
|
||||
}
|
||||
if override.TOTPEnable != "" {
|
||||
merged.TOTPEnable = override.TOTPEnable
|
||||
}
|
||||
if override.TOTPDisable != "" {
|
||||
merged.TOTPDisable = override.TOTPDisable
|
||||
}
|
||||
if override.TOTPGetStatus != "" {
|
||||
merged.TOTPGetStatus = override.TOTPGetStatus
|
||||
}
|
||||
if override.TOTPGetSecret != "" {
|
||||
merged.TOTPGetSecret = override.TOTPGetSecret
|
||||
}
|
||||
if override.TOTPRegenerateBackup != "" {
|
||||
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
|
||||
}
|
||||
if override.TOTPValidateBackupCode != "" {
|
||||
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
|
||||
}
|
||||
if override.PasskeyStoreCredential != "" {
|
||||
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
|
||||
}
|
||||
if override.PasskeyGetCredsByUsername != "" {
|
||||
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
|
||||
}
|
||||
if override.PasskeyGetCredential != "" {
|
||||
merged.PasskeyGetCredential = override.PasskeyGetCredential
|
||||
}
|
||||
if override.PasskeyUpdateCounter != "" {
|
||||
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
|
||||
}
|
||||
if override.PasskeyGetUserCredentials != "" {
|
||||
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
|
||||
}
|
||||
if override.PasskeyDeleteCredential != "" {
|
||||
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
|
||||
}
|
||||
if override.PasskeyUpdateName != "" {
|
||||
merged.PasskeyUpdateName = override.PasskeyUpdateName
|
||||
}
|
||||
if override.PasskeyLogin != "" {
|
||||
merged.PasskeyLogin = override.PasskeyLogin
|
||||
}
|
||||
if override.OAuthGetOrCreateUser != "" {
|
||||
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
|
||||
}
|
||||
if override.OAuthCreateSession != "" {
|
||||
merged.OAuthCreateSession = override.OAuthCreateSession
|
||||
}
|
||||
if override.OAuthGetRefreshToken != "" {
|
||||
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
|
||||
}
|
||||
if override.OAuthUpdateRefreshToken != "" {
|
||||
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
|
||||
}
|
||||
if override.OAuthGetUser != "" {
|
||||
merged.OAuthGetUser = override.OAuthGetUser
|
||||
}
|
||||
if override.OAuthRegisterClient != "" {
|
||||
merged.OAuthRegisterClient = override.OAuthRegisterClient
|
||||
}
|
||||
if override.OAuthGetClient != "" {
|
||||
merged.OAuthGetClient = override.OAuthGetClient
|
||||
}
|
||||
if override.OAuthSaveCode != "" {
|
||||
merged.OAuthSaveCode = override.OAuthSaveCode
|
||||
}
|
||||
if override.OAuthExchangeCode != "" {
|
||||
merged.OAuthExchangeCode = override.OAuthExchangeCode
|
||||
}
|
||||
if override.OAuthIntrospect != "" {
|
||||
merged.OAuthIntrospect = override.OAuthIntrospect
|
||||
}
|
||||
if override.OAuthRevoke != "" {
|
||||
merged.OAuthRevoke = override.OAuthRevoke
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
|
||||
// Returns an error if any field contains invalid characters.
|
||||
func ValidateSQLNames(names *SQLNames) error {
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
val := field.String()
|
||||
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSQLNames merges an optional override with defaults.
|
||||
// Used by constructors that accept variadic *SQLNames.
|
||||
func resolveSQLNames(override ...*SQLNames) *SQLNames {
|
||||
if len(override) > 0 && override[0] != nil {
|
||||
return MergeSQLNames(DefaultSQLNames(), override[0])
|
||||
}
|
||||
return DefaultSQLNames()
|
||||
}
|
||||
145
pkg/security/sql_names_test.go
Normal file
145
pkg/security/sql_names_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if field.String() == "" {
|
||||
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_PartialOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{
|
||||
Login: "custom_login",
|
||||
TOTPEnable: "custom_totp_enable",
|
||||
PasskeyLogin: "custom_passkey_login",
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
|
||||
if merged.Login != "custom_login" {
|
||||
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
|
||||
}
|
||||
if merged.TOTPEnable != "custom_totp_enable" {
|
||||
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
|
||||
}
|
||||
if merged.PasskeyLogin != "custom_passkey_login" {
|
||||
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
|
||||
}
|
||||
// Non-overridden fields should retain defaults
|
||||
if merged.Logout != "resolvespec_logout" {
|
||||
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
|
||||
}
|
||||
if merged.Session != "resolvespec_session" {
|
||||
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_NilOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
merged := MergeSQLNames(base, nil)
|
||||
|
||||
// Should be a copy, not the same pointer
|
||||
if merged == base {
|
||||
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
|
||||
}
|
||||
|
||||
// All values should match
|
||||
v1 := reflect.ValueOf(base).Elem()
|
||||
v2 := reflect.ValueOf(merged).Elem()
|
||||
typ := v1.Type()
|
||||
|
||||
for i := 0; i < v1.NumField(); i++ {
|
||||
f1 := v1.Field(i)
|
||||
f2 := v2.Field(i)
|
||||
if f1.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if f1.String() != f2.String() {
|
||||
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
originalLogin := base.Login
|
||||
|
||||
override := &SQLNames{Login: "custom_login"}
|
||||
_ = MergeSQLNames(base, override)
|
||||
|
||||
if base.Login != originalLogin {
|
||||
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{}
|
||||
v := reflect.ValueOf(override).Elem()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if v.Field(i).Kind() == reflect.String {
|
||||
v.Field(i).SetString("custom_sentinel")
|
||||
}
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
mv := reflect.ValueOf(merged).Elem()
|
||||
typ := mv.Type()
|
||||
for i := 0; i < mv.NumField(); i++ {
|
||||
if mv.Field(i).Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if mv.Field(i).String() != "custom_sentinel" {
|
||||
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Valid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
if err := ValidateSQLNames(names); err != nil {
|
||||
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Invalid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
names.Login = "resolvespec_login; DROP TABLE users; --"
|
||||
|
||||
err := ValidateSQLNames(names)
|
||||
if err == nil {
|
||||
t.Error("ValidateSQLNames should reject names with invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_NoOverride(t *testing.T) {
|
||||
names := resolveSQLNames()
|
||||
if names.Login != "resolvespec_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_WithOverride(t *testing.T) {
|
||||
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
|
||||
if names.Login != "custom_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
|
||||
}
|
||||
if names.Logout != "resolvespec_logout" {
|
||||
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
|
||||
}
|
||||
}
|
||||
@@ -9,23 +9,23 @@ import (
|
||||
)
|
||||
|
||||
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
||||
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
|
||||
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
|
||||
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See totp_database_schema.sql for procedure definitions
|
||||
type DatabaseTwoFactorProvider struct {
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &DatabaseTwoFactorProvider{
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
|
||||
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||
var errorMsg sql.NullString
|
||||
var enabled bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var errorMsg sql.NullString
|
||||
var secret sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
|
||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
|
||||
var errorMsg sql.NullString
|
||||
var valid bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
|
||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||
|
||||
@@ -3,6 +3,7 @@ package server_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
|
||||
GZIP: true, // Enable GZIP compression
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Server is now running...
|
||||
// When done, stop gracefully
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +62,7 @@ func ExampleManager_https() {
|
||||
SSLKey: "/path/to/key.pem",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Option 2: Self-signed certificate (for development)
|
||||
@@ -73,27 +74,27 @@ func ExampleManager_https() {
|
||||
SelfSignedSSL: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Option 3: Let's Encrypt / AutoTLS (for production)
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "https-server-letsencrypt",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
Name: "https-server-letsencrypt",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
AutoTLSCacheDir: "./certs-cache",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
|
||||
IdleTimeout: 120 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
|
||||
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
|
||||
Handler: mux,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Add health and readiness endpoints
|
||||
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
|
||||
|
||||
// Start the server
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Health check returns:
|
||||
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
|
||||
GZIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Admin API server (different port)
|
||||
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
|
||||
Handler: adminHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Metrics server (internal only)
|
||||
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
|
||||
Handler: metricsHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start all servers at once
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Get specific server instance
|
||||
publicInstance, err := mgr.Get("public-api")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
|
||||
|
||||
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
|
||||
|
||||
// Stop all servers gracefully (in parallel)
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
|
||||
Handler: handler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Check server status
|
||||
|
||||
Reference in New Issue
Block a user