mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 01:16:22 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
568df8c6d6 | ||
|
|
aa362c77da | ||
|
|
1641eaf278 | ||
|
|
200a03c225 |
407
pkg/resolvemcp/README.md
Normal file
407
pkg/resolvemcp/README.md
Normal file
@@ -0,0 +1,407 @@
|
||||
# resolvemcp
|
||||
|
||||
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// 1. Create a handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
|
||||
BaseURL: "http://localhost:8080",
|
||||
})
|
||||
|
||||
// 2. Register models
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "orders", &Order{})
|
||||
|
||||
// 3. Mount routes
|
||||
r := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Config
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
// If empty, it is detected from each incoming request using the Host header and
|
||||
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
|
||||
// Required.
|
||||
BasePath string
|
||||
}
|
||||
```
|
||||
|
||||
## Handler Creation
|
||||
|
||||
| Function | Description |
|
||||
|---|---|
|
||||
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
|
||||
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
|
||||
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
|
||||
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
|
||||
|
||||
---
|
||||
|
||||
## Registering Models
|
||||
|
||||
```go
|
||||
handler.RegisterModel(schema, entity string, model interface{}) error
|
||||
```
|
||||
|
||||
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
|
||||
- `entity` — table/entity name (e.g. `"users"`).
|
||||
- `model` — a pointer to a struct (e.g. `&User{}`).
|
||||
|
||||
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
|
||||
|
||||
---
|
||||
|
||||
## HTTP / SSE Transport
|
||||
|
||||
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
|
||||
|
||||
`Config.BasePath` is required and used for all route registration.
|
||||
`Config.BaseURL` is optional — when empty it is detected from each request.
|
||||
|
||||
### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
Registers:
|
||||
|
||||
| Route | Method | Description |
|
||||
|---|---|---|
|
||||
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
|
||||
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
|
||||
| `{BasePath}/*` | any | Full SSE server (convenience prefix) |
|
||||
|
||||
### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterRoutes(router, handler)
|
||||
```
|
||||
|
||||
Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`.
|
||||
|
||||
### Gin (or any `http.Handler`-compatible framework)
|
||||
|
||||
Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
|
||||
|
||||
```go
|
||||
sse := handler.SSEServer()
|
||||
|
||||
// Gin
|
||||
engine.Any("/mcp/*path", gin.WrapH(sse))
|
||||
|
||||
// net/http
|
||||
http.Handle("/mcp/", sse)
|
||||
|
||||
// Echo
|
||||
e.Any("/mcp/*", echo.WrapHandler(sse))
|
||||
```
|
||||
|
||||
### Authentication
|
||||
|
||||
Add middleware before the MCP routes. The handler itself has no auth layer.
|
||||
|
||||
---
|
||||
|
||||
## MCP Tools
|
||||
|
||||
### Tool Naming
|
||||
|
||||
```
|
||||
{operation}_{schema}_{entity} // e.g. read_public_users
|
||||
{operation}_{entity} // e.g. read_users (when schema is empty)
|
||||
```
|
||||
|
||||
Operations: `read`, `create`, `update`, `delete`.
|
||||
|
||||
### Read Tool — `read_{schema}_{entity}`
|
||||
|
||||
Fetch one or many records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key value. Omit to return multiple records. |
|
||||
| `limit` | number | Max records per page (recommended: 10–100). |
|
||||
| `offset` | number | Records to skip (offset-based pagination). |
|
||||
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
|
||||
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
|
||||
| `columns` | array | Column names to include. Omit for all columns. |
|
||||
| `omit_columns` | array | Column names to exclude. |
|
||||
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
|
||||
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
|
||||
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 100,
|
||||
"count": 10,
|
||||
"limit": 10,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Create Tool — `create_{schema}_{entity}`
|
||||
|
||||
Insert one or more records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `data` | object \| array | Single object or array of objects to insert. |
|
||||
|
||||
Array input runs inside a single transaction — all succeed or all fail.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ... } }
|
||||
```
|
||||
|
||||
### Update Tool — `update_{schema}_{entity}`
|
||||
|
||||
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key of the record. Can also be included inside `data`. |
|
||||
| `data` | object (required) | Fields to update. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...merged record... } }
|
||||
```
|
||||
|
||||
### Delete Tool — `delete_{schema}_{entity}`
|
||||
|
||||
Delete a record by primary key. **Irreversible.**
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string (required) | Primary key of the record to delete. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...deleted record... } }
|
||||
```
|
||||
|
||||
### Resource — `{schema}.{entity}`
|
||||
|
||||
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||
|
||||
---
|
||||
|
||||
## Filtering
|
||||
|
||||
Pass an array of filter objects to the `filters` argument:
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "status", "operator": "=", "value": "active" },
|
||||
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
|
||||
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
|
||||
]
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
| Operator | Aliases | Description |
|
||||
|---|---|---|
|
||||
| `=` | `eq` | Equal |
|
||||
| `!=` | `neq`, `<>` | Not equal |
|
||||
| `>` | `gt` | Greater than |
|
||||
| `>=` | `gte` | Greater than or equal |
|
||||
| `<` | `lt` | Less than |
|
||||
| `<=` | `lte` | Less than or equal |
|
||||
| `like` | | SQL LIKE (case-sensitive) |
|
||||
| `ilike` | | SQL ILIKE (case-insensitive) |
|
||||
| `in` | | Value in list |
|
||||
| `is_null` | | Column IS NULL |
|
||||
| `is_not_null` | | Column IS NOT NULL |
|
||||
|
||||
### Logic Operators
|
||||
|
||||
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
|
||||
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
|
||||
|
||||
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
|
||||
|
||||
---
|
||||
|
||||
## Sorting
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "created_at", "direction": "desc" },
|
||||
{ "column": "name", "direction": "asc" }
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pagination
|
||||
|
||||
### Offset-Based
|
||||
|
||||
```json
|
||||
{ "limit": 20, "offset": 40 }
|
||||
```
|
||||
|
||||
### Cursor-Based
|
||||
|
||||
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
|
||||
|
||||
```json
|
||||
// Next page: pass the PK of the last record on the current page
|
||||
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
|
||||
// Previous page: pass the PK of the first record on the current page
|
||||
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Preloading Relations
|
||||
|
||||
```json
|
||||
[
|
||||
{ "relation": "Profile" },
|
||||
{ "relation": "Orders" }
|
||||
]
|
||||
```
|
||||
|
||||
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
|
||||
|
||||
---
|
||||
|
||||
## Hook System
|
||||
|
||||
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
|
||||
|
||||
### Hook Types
|
||||
|
||||
| Constant | Fires |
|
||||
|---|---|
|
||||
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
|
||||
| `BeforeRead` / `AfterRead` | Around read queries |
|
||||
| `BeforeCreate` / `AfterCreate` | Around insert |
|
||||
| `BeforeUpdate` / `AfterUpdate` | Around update |
|
||||
| `BeforeDelete` / `AfterDelete` | Around delete |
|
||||
|
||||
### Registering Hooks
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
|
||||
// Inject a timestamp before insert
|
||||
if data, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
data["created_at"] = time.Now()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register the same hook for multiple events
|
||||
handler.Hooks().RegisterMultiple(
|
||||
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
|
||||
auditHook,
|
||||
)
|
||||
```
|
||||
|
||||
### HookContext Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `Context` | `context.Context` | Request context |
|
||||
| `Handler` | `*Handler` | The resolvemcp handler |
|
||||
| `Schema` | `string` | Database schema name |
|
||||
| `Entity` | `string` | Entity/table name |
|
||||
| `Model` | `interface{}` | Registered model instance |
|
||||
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
|
||||
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
|
||||
| `ID` | `string` | Primary key from request (read/update/delete) |
|
||||
| `Data` | `interface{}` | Input data (create/update — modifiable) |
|
||||
| `Result` | `interface{}` | Output data (set by After hooks) |
|
||||
| `Error` | `error` | Operation error, if any |
|
||||
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
|
||||
| `Tx` | `common.Database` | Database/transaction handle |
|
||||
| `Abort` | `bool` | Set to `true` to abort the operation |
|
||||
| `AbortMessage` | `string` | Error message returned when aborting |
|
||||
| `AbortCode` | `int` | Optional status code for the abort |
|
||||
|
||||
### Aborting an Operation
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
|
||||
ctx.Abort = true
|
||||
ctx.AbortMessage = "deletion is disabled"
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Managing Hooks
|
||||
|
||||
```go
|
||||
registry := handler.Hooks()
|
||||
registry.HasHooks(resolvemcp.BeforeCreate) // bool
|
||||
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
|
||||
registry.ClearAll() // remove all hooks
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context Helpers
|
||||
|
||||
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
|
||||
|
||||
```go
|
||||
schema := resolvemcp.GetSchema(ctx)
|
||||
entity := resolvemcp.GetEntity(ctx)
|
||||
tableName := resolvemcp.GetTableName(ctx)
|
||||
model := resolvemcp.GetModel(ctx)
|
||||
modelPtr := resolvemcp.GetModelPtr(ctx)
|
||||
```
|
||||
|
||||
You can also set values manually (e.g. in middleware):
|
||||
|
||||
```go
|
||||
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding Custom MCP Tools
|
||||
|
||||
Access the underlying `*server.MCPServer` to register additional tools:
|
||||
|
||||
```go
|
||||
mcpServer := handler.MCPServer()
|
||||
mcpServer.AddTool(myTool, myHandler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Table Name Resolution
|
||||
|
||||
The handler resolves table names in priority order:
|
||||
|
||||
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
|
||||
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
|
||||
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)
|
||||
@@ -46,7 +46,7 @@ func getCursorFilter(
|
||||
reverse := direction < 0
|
||||
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
@@ -21,17 +23,19 @@ type Handler struct {
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database and model registry.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||
config: cfg,
|
||||
name: "resolvemcp",
|
||||
version: "1.0.0",
|
||||
}
|
||||
@@ -52,6 +56,63 @@ func (h *Handler) MCPServer() *server.MCPServer {
|
||||
return h.mcpServer
|
||||
}
|
||||
|
||||
// SSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
|
||||
// detected automatically from each incoming request.
|
||||
func (h *Handler) SSEServer() http.Handler {
|
||||
if h.config.BaseURL != "" {
|
||||
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
|
||||
}
|
||||
return &dynamicSSEHandler{h: h}
|
||||
}
|
||||
|
||||
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
|
||||
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
h.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath(basePath),
|
||||
)
|
||||
}
|
||||
|
||||
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
|
||||
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
|
||||
type dynamicSSEHandler struct {
|
||||
h *Handler
|
||||
mu sync.Mutex
|
||||
pool map[string]*server.SSEServer
|
||||
}
|
||||
|
||||
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := requestBaseURL(r)
|
||||
|
||||
d.mu.Lock()
|
||||
if d.pool == nil {
|
||||
d.pool = make(map[string]*server.SSEServer)
|
||||
}
|
||||
s, ok := d.pool[baseURL]
|
||||
if !ok {
|
||||
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
|
||||
d.pool[baseURL] = s
|
||||
}
|
||||
d.mu.Unlock()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// requestBaseURL builds the base URL from an incoming request.
|
||||
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
|
||||
func requestBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
}
|
||||
return scheme + "://" + r.Host
|
||||
}
|
||||
|
||||
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||
fullName := buildModelName(schema, entity)
|
||||
|
||||
@@ -8,19 +8,19 @@
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// handler := resolvemcp.NewHandlerWithGORM(db)
|
||||
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080")
|
||||
// resolvemcp.SetupMuxRoutes(r, handler)
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"github.com/uptrace/bun"
|
||||
bunrouter "github.com/uptrace/bunrouter"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
@@ -28,56 +28,73 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// Config holds configuration for the resolvemcp handler.
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
|
||||
// If empty, the path is detected from each incoming request automatically.
|
||||
BasePath string
|
||||
}
|
||||
|
||||
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||
func NewHandlerWithGORM(db *gorm.DB) *Handler {
|
||||
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry())
|
||||
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||
func NewHandlerWithBun(db *bun.DB) *Handler {
|
||||
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry())
|
||||
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||
func NewHandlerWithDB(db common.Database) *Handler {
|
||||
return NewHandler(db, modelregistry.NewModelRegistry())
|
||||
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
|
||||
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router.
|
||||
//
|
||||
// baseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
|
||||
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET /mcp/sse — SSE connection endpoint (client subscribes here)
|
||||
// - POST /mcp/message — JSON-RPC message endpoint (client sends requests here)
|
||||
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
|
||||
//
|
||||
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||
// before calling SetupMuxRoutes.
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, baseURL string) {
|
||||
sseServer := server.NewSSEServer(
|
||||
handler.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath("/mcp"),
|
||||
)
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
muxRouter.Handle("/mcp/sse", sseServer.SSEHandler()).Methods("GET", "OPTIONS")
|
||||
muxRouter.Handle("/mcp/message", sseServer.MessageHandler()).Methods("POST", "OPTIONS")
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
|
||||
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
|
||||
|
||||
// Convenience: also expose the full SSE server at /mcp for clients that
|
||||
// Convenience: also expose the full SSE server at basePath for clients that
|
||||
// use ServeHTTP directly (e.g. net/http default mux).
|
||||
muxRouter.PathPrefix("/mcp").Handler(http.StripPrefix("/mcp", sseServer))
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// NewSSEServer creates an *server.SSEServer that can be mounted manually,
|
||||
// useful when integrating with non-Mux routers or adding extra middleware.
|
||||
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
|
||||
// using the base path from Config.BasePath.
|
||||
//
|
||||
// sseServer := resolvemcp.NewSSEServer(handler, "http://localhost:8080", "/mcp")
|
||||
// http.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
|
||||
func NewSSEServer(handler *Handler, baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
handler.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath(basePath),
|
||||
)
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint
|
||||
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewSSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// If Config.BasePath is set it is used directly; otherwise the base path is
|
||||
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
|
||||
//
|
||||
// h := resolvemcp.NewSSEServer(handler)
|
||||
// http.Handle("/api/mcp/", h)
|
||||
func NewSSEServer(handler *Handler) http.Handler {
|
||||
return handler.SSEServer()
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||
// Add to blacklist
|
||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
||||
"token": req.Token,
|
||||
"user_id": req.UserID,
|
||||
}).Error
|
||||
// Invalidate session via stored procedure
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||
|
||||
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// For JWT, logout could involve token blacklisting
|
||||
// Add token to blacklist table
|
||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
||||
// "token": req.Token,
|
||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
||||
// }).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
||||
`, userJSON).Scan(&success, &errMsg, &userID)
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_createsession($1::jsonb)
|
||||
`, sessionJSON).Scan(&success, &errMsg)
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getrefreshtoken($1)
|
||||
`, refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
|
||||
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update session: %w", err)
|
||||
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getuser($1)
|
||||
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
|
||||
RPOrigin string
|
||||
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
opts.Timeout = 60000 // 60 seconds default
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var userID sql.NullInt64
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var updateError sql.NullString
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var errorMsg sql.NullString
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
|
||||
@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
|
||||
// DatabaseAuthenticator provides session-based authentication with database storage
|
||||
// All database operations go through stored procedures for security and consistency
|
||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||
// resolvespec_session_update, resolvespec_refresh_token
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See database_schema.sql for procedure definitions
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
|
||||
Cache *cache.Cache
|
||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||
PasskeyProvider PasskeyProvider
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
// Partial overrides are supported: only set the fields you want to change.
|
||||
SQLNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
cacheInstance = cache.GetDefaultCache()
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
sqlNames: sqlNames,
|
||||
passkeyProvider: opts.PasskeyProvider,
|
||||
}
|
||||
}
|
||||
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
return nil, fmt.Errorf("failed to marshal login request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_register stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
return fmt.Errorf("failed to marshal logout request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
@@ -266,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
@@ -338,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
return
|
||||
}
|
||||
|
||||
// Call resolvespec_session_update stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||||
// Call api_refresh_token stored procedure
|
||||
// First, we need to get the current user context for the refresh token
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
@@ -368,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
return nil, fmt.Errorf("invalid refresh token")
|
||||
}
|
||||
|
||||
// Call resolvespec_refresh_token to generate new token
|
||||
var newSuccess bool
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
@@ -401,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
|
||||
// JWTAuthenticator provides JWT token-based authentication
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
|
||||
type JWTAuthenticator struct {
|
||||
secretKey []byte
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
|
||||
return &JWTAuthenticator{
|
||||
secretKey: []byte(secretKey),
|
||||
db: db,
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Call resolvespec_jwt_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
@@ -471,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// Call resolvespec_jwt_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
@@ -511,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
|
||||
// DatabaseColumnSecurityProvider loads column security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_column_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseColumnSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db}
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||
var rules []ColumnSecurity
|
||||
|
||||
// Call resolvespec_column_security stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var rulesJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||
@@ -576,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
|
||||
// DatabaseRowSecurityProvider loads row security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_row_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseRowSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db}
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||
var template string
|
||||
var hasBlock bool
|
||||
|
||||
// Call resolvespec_row_security stored procedure
|
||||
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
|
||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
if err != nil {
|
||||
@@ -758,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user data from database
|
||||
var username, email, roles string
|
||||
var userLevel int
|
||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
// Build request JSON for passkey login stored procedure
|
||||
reqData := map[string]any{
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Extract IP and user agent from claims
|
||||
ipAddress := ""
|
||||
userAgent := ""
|
||||
if req.Claims != nil {
|
||||
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||
ipAddress = ip
|
||||
reqData["ip_address"] = ip
|
||||
}
|
||||
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||
userAgent = ua
|
||||
reqData["user_agent"] = ua
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, now())`
|
||||
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
||||
reqJSON, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
|
||||
}
|
||||
|
||||
// Update last login
|
||||
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
||||
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
// Return login response
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
User: &UserContext{
|
||||
UserID: userID,
|
||||
UserName: username,
|
||||
Email: email,
|
||||
UserLevel: userLevel,
|
||||
SessionID: sessionToken,
|
||||
Roles: parseRoles(roles),
|
||||
},
|
||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||
}, nil
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("passkey login failed")
|
||||
}
|
||||
|
||||
var response LoginResponse
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||
|
||||
222
pkg/security/sql_names.go
Normal file
222
pkg/security/sql_names.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// SQLNames defines all configurable SQL stored procedure and table names
|
||||
// used by the security package. Override individual fields to remap
|
||||
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
|
||||
// and MergeSQLNames() to apply partial overrides.
|
||||
type SQLNames struct {
|
||||
// Auth procedures (DatabaseAuthenticator)
|
||||
Login string // default: "resolvespec_login"
|
||||
Register string // default: "resolvespec_register"
|
||||
Logout string // default: "resolvespec_logout"
|
||||
Session string // default: "resolvespec_session"
|
||||
SessionUpdate string // default: "resolvespec_session_update"
|
||||
RefreshToken string // default: "resolvespec_refresh_token"
|
||||
|
||||
// JWT procedures (JWTAuthenticator)
|
||||
JWTLogin string // default: "resolvespec_jwt_login"
|
||||
JWTLogout string // default: "resolvespec_jwt_logout"
|
||||
|
||||
// Security policy procedures
|
||||
ColumnSecurity string // default: "resolvespec_column_security"
|
||||
RowSecurity string // default: "resolvespec_row_security"
|
||||
|
||||
// TOTP procedures (DatabaseTwoFactorProvider)
|
||||
TOTPEnable string // default: "resolvespec_totp_enable"
|
||||
TOTPDisable string // default: "resolvespec_totp_disable"
|
||||
TOTPGetStatus string // default: "resolvespec_totp_get_status"
|
||||
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
|
||||
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
|
||||
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
|
||||
|
||||
// Passkey procedures (DatabasePasskeyProvider)
|
||||
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
|
||||
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
|
||||
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
|
||||
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
|
||||
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
|
||||
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
|
||||
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
|
||||
PasskeyLogin string // default: "resolvespec_passkey_login"
|
||||
|
||||
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
|
||||
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
|
||||
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
|
||||
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
|
||||
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
|
||||
OAuthGetUser string // default: "resolvespec_oauth_getuser"
|
||||
|
||||
}
|
||||
|
||||
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
|
||||
func DefaultSQLNames() *SQLNames {
|
||||
return &SQLNames{
|
||||
Login: "resolvespec_login",
|
||||
Register: "resolvespec_register",
|
||||
Logout: "resolvespec_logout",
|
||||
Session: "resolvespec_session",
|
||||
SessionUpdate: "resolvespec_session_update",
|
||||
RefreshToken: "resolvespec_refresh_token",
|
||||
|
||||
JWTLogin: "resolvespec_jwt_login",
|
||||
JWTLogout: "resolvespec_jwt_logout",
|
||||
|
||||
ColumnSecurity: "resolvespec_column_security",
|
||||
RowSecurity: "resolvespec_row_security",
|
||||
|
||||
TOTPEnable: "resolvespec_totp_enable",
|
||||
TOTPDisable: "resolvespec_totp_disable",
|
||||
TOTPGetStatus: "resolvespec_totp_get_status",
|
||||
TOTPGetSecret: "resolvespec_totp_get_secret",
|
||||
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
|
||||
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
|
||||
|
||||
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
|
||||
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
|
||||
PasskeyGetCredential: "resolvespec_passkey_get_credential",
|
||||
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
|
||||
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
|
||||
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
|
||||
PasskeyUpdateName: "resolvespec_passkey_update_name",
|
||||
PasskeyLogin: "resolvespec_passkey_login",
|
||||
|
||||
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
|
||||
OAuthCreateSession: "resolvespec_oauth_createsession",
|
||||
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
|
||||
OAuthGetUser: "resolvespec_oauth_getuser",
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||
// If override is nil, a copy of base is returned.
|
||||
func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||
if override == nil {
|
||||
copied := *base
|
||||
return &copied
|
||||
}
|
||||
merged := *base
|
||||
if override.Login != "" {
|
||||
merged.Login = override.Login
|
||||
}
|
||||
if override.Register != "" {
|
||||
merged.Register = override.Register
|
||||
}
|
||||
if override.Logout != "" {
|
||||
merged.Logout = override.Logout
|
||||
}
|
||||
if override.Session != "" {
|
||||
merged.Session = override.Session
|
||||
}
|
||||
if override.SessionUpdate != "" {
|
||||
merged.SessionUpdate = override.SessionUpdate
|
||||
}
|
||||
if override.RefreshToken != "" {
|
||||
merged.RefreshToken = override.RefreshToken
|
||||
}
|
||||
if override.JWTLogin != "" {
|
||||
merged.JWTLogin = override.JWTLogin
|
||||
}
|
||||
if override.JWTLogout != "" {
|
||||
merged.JWTLogout = override.JWTLogout
|
||||
}
|
||||
if override.ColumnSecurity != "" {
|
||||
merged.ColumnSecurity = override.ColumnSecurity
|
||||
}
|
||||
if override.RowSecurity != "" {
|
||||
merged.RowSecurity = override.RowSecurity
|
||||
}
|
||||
if override.TOTPEnable != "" {
|
||||
merged.TOTPEnable = override.TOTPEnable
|
||||
}
|
||||
if override.TOTPDisable != "" {
|
||||
merged.TOTPDisable = override.TOTPDisable
|
||||
}
|
||||
if override.TOTPGetStatus != "" {
|
||||
merged.TOTPGetStatus = override.TOTPGetStatus
|
||||
}
|
||||
if override.TOTPGetSecret != "" {
|
||||
merged.TOTPGetSecret = override.TOTPGetSecret
|
||||
}
|
||||
if override.TOTPRegenerateBackup != "" {
|
||||
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
|
||||
}
|
||||
if override.TOTPValidateBackupCode != "" {
|
||||
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
|
||||
}
|
||||
if override.PasskeyStoreCredential != "" {
|
||||
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
|
||||
}
|
||||
if override.PasskeyGetCredsByUsername != "" {
|
||||
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
|
||||
}
|
||||
if override.PasskeyGetCredential != "" {
|
||||
merged.PasskeyGetCredential = override.PasskeyGetCredential
|
||||
}
|
||||
if override.PasskeyUpdateCounter != "" {
|
||||
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
|
||||
}
|
||||
if override.PasskeyGetUserCredentials != "" {
|
||||
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
|
||||
}
|
||||
if override.PasskeyDeleteCredential != "" {
|
||||
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
|
||||
}
|
||||
if override.PasskeyUpdateName != "" {
|
||||
merged.PasskeyUpdateName = override.PasskeyUpdateName
|
||||
}
|
||||
if override.PasskeyLogin != "" {
|
||||
merged.PasskeyLogin = override.PasskeyLogin
|
||||
}
|
||||
if override.OAuthGetOrCreateUser != "" {
|
||||
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
|
||||
}
|
||||
if override.OAuthCreateSession != "" {
|
||||
merged.OAuthCreateSession = override.OAuthCreateSession
|
||||
}
|
||||
if override.OAuthGetRefreshToken != "" {
|
||||
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
|
||||
}
|
||||
if override.OAuthUpdateRefreshToken != "" {
|
||||
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
|
||||
}
|
||||
if override.OAuthGetUser != "" {
|
||||
merged.OAuthGetUser = override.OAuthGetUser
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
|
||||
// Returns an error if any field contains invalid characters.
|
||||
func ValidateSQLNames(names *SQLNames) error {
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
val := field.String()
|
||||
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSQLNames merges an optional override with defaults.
|
||||
// Used by constructors that accept variadic *SQLNames.
|
||||
func resolveSQLNames(override ...*SQLNames) *SQLNames {
|
||||
if len(override) > 0 && override[0] != nil {
|
||||
return MergeSQLNames(DefaultSQLNames(), override[0])
|
||||
}
|
||||
return DefaultSQLNames()
|
||||
}
|
||||
145
pkg/security/sql_names_test.go
Normal file
145
pkg/security/sql_names_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if field.String() == "" {
|
||||
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_PartialOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{
|
||||
Login: "custom_login",
|
||||
TOTPEnable: "custom_totp_enable",
|
||||
PasskeyLogin: "custom_passkey_login",
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
|
||||
if merged.Login != "custom_login" {
|
||||
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
|
||||
}
|
||||
if merged.TOTPEnable != "custom_totp_enable" {
|
||||
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
|
||||
}
|
||||
if merged.PasskeyLogin != "custom_passkey_login" {
|
||||
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
|
||||
}
|
||||
// Non-overridden fields should retain defaults
|
||||
if merged.Logout != "resolvespec_logout" {
|
||||
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
|
||||
}
|
||||
if merged.Session != "resolvespec_session" {
|
||||
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_NilOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
merged := MergeSQLNames(base, nil)
|
||||
|
||||
// Should be a copy, not the same pointer
|
||||
if merged == base {
|
||||
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
|
||||
}
|
||||
|
||||
// All values should match
|
||||
v1 := reflect.ValueOf(base).Elem()
|
||||
v2 := reflect.ValueOf(merged).Elem()
|
||||
typ := v1.Type()
|
||||
|
||||
for i := 0; i < v1.NumField(); i++ {
|
||||
f1 := v1.Field(i)
|
||||
f2 := v2.Field(i)
|
||||
if f1.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if f1.String() != f2.String() {
|
||||
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
originalLogin := base.Login
|
||||
|
||||
override := &SQLNames{Login: "custom_login"}
|
||||
_ = MergeSQLNames(base, override)
|
||||
|
||||
if base.Login != originalLogin {
|
||||
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{}
|
||||
v := reflect.ValueOf(override).Elem()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if v.Field(i).Kind() == reflect.String {
|
||||
v.Field(i).SetString("custom_sentinel")
|
||||
}
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
mv := reflect.ValueOf(merged).Elem()
|
||||
typ := mv.Type()
|
||||
for i := 0; i < mv.NumField(); i++ {
|
||||
if mv.Field(i).Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if mv.Field(i).String() != "custom_sentinel" {
|
||||
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Valid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
if err := ValidateSQLNames(names); err != nil {
|
||||
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Invalid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
names.Login = "resolvespec_login; DROP TABLE users; --"
|
||||
|
||||
err := ValidateSQLNames(names)
|
||||
if err == nil {
|
||||
t.Error("ValidateSQLNames should reject names with invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_NoOverride(t *testing.T) {
|
||||
names := resolveSQLNames()
|
||||
if names.Login != "resolvespec_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_WithOverride(t *testing.T) {
|
||||
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
|
||||
if names.Login != "custom_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
|
||||
}
|
||||
if names.Logout != "resolvespec_logout" {
|
||||
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
|
||||
}
|
||||
}
|
||||
@@ -9,23 +9,23 @@ import (
|
||||
)
|
||||
|
||||
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
||||
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
|
||||
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
|
||||
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See totp_database_schema.sql for procedure definitions
|
||||
type DatabaseTwoFactorProvider struct {
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &DatabaseTwoFactorProvider{
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
|
||||
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||
var errorMsg sql.NullString
|
||||
var enabled bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var errorMsg sql.NullString
|
||||
var secret sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
|
||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
|
||||
var errorMsg sql.NullString
|
||||
var valid bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
|
||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user