Compare commits

...

3 Commits

Author SHA1 Message Date
Hein
568df8c6d6 feat(security): add configurable SQL procedure names
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m29s
Build , Vet Test, and Lint / Build (push) Successful in -30m5s
Build , Vet Test, and Lint / Lint Code (push) Failing after -28m58s
Tests / Integration Tests (push) Failing after -30m26s
Tests / Unit Tests (push) Successful in -28m7s
* Introduce SQLNames struct to define stored procedure names.
* Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls.
* Remove hardcoded procedure names for better flexibility and customization.
* Implement validation for SQL names to ensure they are valid identifiers.
* Add tests for SQLNames functionality and merging behavior.
2026-03-31 14:25:59 +02:00
Hein
aa362c77da fix(cursor): trim parentheses from sort column names 2026-03-27 15:07:10 +02:00
Hein
1641eaf278 feat(resolvemcp): enhance handler with configuration support
* Introduce Config struct for BaseURL and BasePath settings
* Update handler creation functions to accept configuration
* Modify SSEServer to use dynamic base URL detection
* Adjust route setup functions to utilize BasePath from config
2026-03-27 13:56:03 +02:00
14 changed files with 619 additions and 194 deletions

View File

@@ -11,7 +11,9 @@ import (
) )
// 1. Create a handler // 1. Create a handler
handler := resolvemcp.NewHandlerWithGORM(db) handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
BaseURL: "http://localhost:8080",
})
// 2. Register models // 2. Register models
handler.RegisterModel("public", "users", &User{}) handler.RegisterModel("public", "users", &User{})
@@ -19,19 +21,35 @@ handler.RegisterModel("public", "orders", &Order{})
// 3. Mount routes // 3. Mount routes
r := mux.NewRouter() r := mux.NewRouter()
resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080") resolvemcp.SetupMuxRoutes(r, handler)
``` ```
--- ---
## Config
```go
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
// If empty, it is detected from each incoming request using the Host header and
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
BaseURL string
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
// Required.
BasePath string
}
```
## Handler Creation ## Handler Creation
| Function | Description | | Function | Description |
|---|---| |---|---|
| `NewHandlerWithGORM(db *gorm.DB) *Handler` | Backed by GORM | | `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
| `NewHandlerWithBun(db *bun.DB) *Handler` | Backed by Bun | | `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
| `NewHandlerWithDB(db common.Database) *Handler` | Backed by any `common.Database` | | `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
| `NewHandler(db common.Database, registry common.ModelRegistry) *Handler` | Full control over registry | | `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
--- ---
@@ -53,40 +71,43 @@ Each call immediately creates four MCP **tools** and one MCP **resource** for th
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework. The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
`Config.BasePath` is required and used for all route registration.
`Config.BaseURL` is optional — when empty it is detected from each request.
### Gorilla Mux ### Gorilla Mux
```go ```go
resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080") resolvemcp.SetupMuxRoutes(r, handler)
``` ```
Registers: Registers:
| Route | Method | Description | | Route | Method | Description |
|---|---|---| |---|---|---|
| `/mcp/sse` | GET | SSE connection — clients subscribe here | | `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
| `/mcp/message` | POST | JSON-RPC — clients send requests here | | `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
| `/mcp/*` | any | Full SSE server (convenience prefix) | | `{BasePath}/*` | any | Full SSE server (convenience prefix) |
### bunrouter ### bunrouter
```go ```go
resolvemcp.SetupBunRouterRoutes(router, handler, "http://localhost:8080", "/mcp") resolvemcp.SetupBunRouterRoutes(router, handler)
``` ```
Registers `GET /mcp/sse` and `POST /mcp/message` on the provided `*bunrouter.Router`. Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`.
### Gin (or any `http.Handler`-compatible framework) ### Gin (or any `http.Handler`-compatible framework)
Use `handler.SSEServer` to get a pre-bound `*server.SSEServer` and wrap it with the framework's adapter: Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
```go ```go
sse := handler.SSEServer("http://localhost:8080", "/mcp") sse := handler.SSEServer()
// Gin // Gin
engine.Any("/mcp/*path", gin.WrapH(sse)) engine.Any("/mcp/*path", gin.WrapH(sse))
// net/http // net/http
http.Handle("/mcp/", http.StripPrefix("/mcp", sse)) http.Handle("/mcp/", sse)
// Echo // Echo
e.Any("/mcp/*", echo.WrapHandler(sse)) e.Any("/mcp/*", echo.WrapHandler(sse))

View File

@@ -46,7 +46,7 @@ func getCursorFilter(
reverse := direction < 0 reverse := direction < 0
for _, s := range sortItems { for _, s := range sortItems {
col := strings.TrimSpace(s.Column) col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" { if col == "" {
continue continue
} }

View File

@@ -5,8 +5,10 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/mark3labs/mcp-go/server" "github.com/mark3labs/mcp-go/server"
@@ -21,17 +23,19 @@ type Handler struct {
registry common.ModelRegistry registry common.ModelRegistry
hooks *HookRegistry hooks *HookRegistry
mcpServer *server.MCPServer mcpServer *server.MCPServer
config Config
name string name string
version string version string
} }
// NewHandler creates a Handler with the given database and model registry. // NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler { func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
return &Handler{ return &Handler{
db: db, db: db,
registry: registry, registry: registry,
hooks: NewHookRegistry(), hooks: NewHookRegistry(),
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"), mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
config: cfg,
name: "resolvemcp", name: "resolvemcp",
version: "1.0.0", version: "1.0.0",
} }
@@ -52,12 +56,18 @@ func (h *Handler) MCPServer() *server.MCPServer {
return h.mcpServer return h.mcpServer
} }
// SSEServer creates an *server.SSEServer bound to this handler. // SSEServer returns an http.Handler that serves MCP over SSE.
// Use it to mount MCP on any HTTP framework that accepts http.Handler. // Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
// // detected automatically from each incoming request.
// sse := handler.SSEServer("http://localhost:8080", "/mcp") func (h *Handler) SSEServer() http.Handler {
// ginEngine.Any("/mcp/*path", gin.WrapH(sse)) if h.config.BaseURL != "" {
func (h *Handler) SSEServer(baseURL, basePath string) *server.SSEServer { return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
}
return &dynamicSSEHandler{h: h}
}
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer( return server.NewSSEServer(
h.mcpServer, h.mcpServer,
server.WithBaseURL(baseURL), server.WithBaseURL(baseURL),
@@ -65,6 +75,44 @@ func (h *Handler) SSEServer(baseURL, basePath string) *server.SSEServer {
) )
} }
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
type dynamicSSEHandler struct {
h *Handler
mu sync.Mutex
pool map[string]*server.SSEServer
}
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
baseURL := requestBaseURL(r)
d.mu.Lock()
if d.pool == nil {
d.pool = make(map[string]*server.SSEServer)
}
s, ok := d.pool[baseURL]
if !ok {
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
d.pool[baseURL] = s
}
d.mu.Unlock()
s.ServeHTTP(w, r)
}
// requestBaseURL builds the base URL from an incoming request.
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
func requestBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = proto
}
return scheme + "://" + r.Host
}
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource. // RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error { func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
fullName := buildModelName(schema, entity) fullName := buildModelName(schema, entity)

View File

@@ -8,18 +8,17 @@
// //
// Usage: // Usage:
// //
// handler := resolvemcp.NewHandlerWithGORM(db) // handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
// handler.RegisterModel("public", "users", &User{}) // handler.RegisterModel("public", "users", &User{})
// //
// r := mux.NewRouter() // r := mux.NewRouter()
// resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080") // resolvemcp.SetupMuxRoutes(r, handler)
package resolvemcp package resolvemcp
import ( import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/mark3labs/mcp-go/server"
"github.com/uptrace/bun" "github.com/uptrace/bun"
bunrouter "github.com/uptrace/bunrouter" bunrouter "github.com/uptrace/bunrouter"
"gorm.io/gorm" "gorm.io/gorm"
@@ -29,72 +28,73 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
) )
// Config holds configuration for the resolvemcp handler.
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
BaseURL string
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
// If empty, the path is detected from each incoming request automatically.
BasePath string
}
// NewHandlerWithGORM creates a Handler backed by a GORM database connection. // NewHandlerWithGORM creates a Handler backed by a GORM database connection.
func NewHandlerWithGORM(db *gorm.DB) *Handler { func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry()) return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
} }
// NewHandlerWithBun creates a Handler backed by a Bun database connection. // NewHandlerWithBun creates a Handler backed by a Bun database connection.
func NewHandlerWithBun(db *bun.DB) *Handler { func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry()) return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
} }
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry. // NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
func NewHandlerWithDB(db common.Database) *Handler { func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
return NewHandler(db, modelregistry.NewModelRegistry()) return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
} }
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router. // SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
// // using the base path from Config.BasePath (falls back to "/mcp" if empty).
// baseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
// //
// Two routes are registered: // Two routes are registered:
// - GET /mcp/sse — SSE connection endpoint (client subscribes here) // - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
// - POST /mcp/message — JSON-RPC message endpoint (client sends requests here) // - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
// //
// To protect these routes with authentication, wrap the mux router or apply middleware // To protect these routes with authentication, wrap the mux router or apply middleware
// before calling SetupMuxRoutes. // before calling SetupMuxRoutes.
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, baseURL string) { func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
sseServer := server.NewSSEServer( basePath := handler.config.BasePath
handler.mcpServer, h := handler.SSEServer()
server.WithBaseURL(baseURL),
server.WithBasePath("/mcp"),
)
muxRouter.Handle("/mcp/sse", sseServer.SSEHandler()).Methods("GET", "OPTIONS") muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
muxRouter.Handle("/mcp/message", sseServer.MessageHandler()).Methods("POST", "OPTIONS") muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
// Convenience: also expose the full SSE server at /mcp for clients that // Convenience: also expose the full SSE server at basePath for clients that
// use ServeHTTP directly (e.g. net/http default mux). // use ServeHTTP directly (e.g. net/http default mux).
muxRouter.PathPrefix("/mcp").Handler(http.StripPrefix("/mcp", sseServer)) muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
} }
// NewSSEServer creates an *server.SSEServer that can be mounted manually, // SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
// useful when integrating with non-Mux routers or adding extra middleware. // using the base path from Config.BasePath.
// //
// sseServer := resolvemcp.NewSSEServer(handler, "http://localhost:8080", "/mcp") // Two routes are registered:
// http.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
func NewSSEServer(handler *Handler, baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer(
handler.mcpServer,
server.WithBaseURL(baseURL),
server.WithBasePath(basePath),
)
}
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router.
//
// Two routes are registered under the given basePath prefix:
// - GET {basePath}/sse — SSE connection endpoint // - GET {basePath}/sse — SSE connection endpoint
// - POST {basePath}/message — JSON-RPC message endpoint // - POST {basePath}/message — JSON-RPC message endpoint
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler, baseURL, basePath string) { func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
sseServer := server.NewSSEServer( basePath := handler.config.BasePath
handler.mcpServer, h := handler.SSEServer()
server.WithBaseURL(baseURL),
server.WithBasePath(basePath),
)
router.GET(basePath+"/sse", bunrouter.HTTPHandler(sseServer.SSEHandler())) router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
router.POST(basePath+"/message", bunrouter.HTTPHandler(sseServer.MessageHandler())) router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
}
// NewSSEServer returns an http.Handler that serves MCP over SSE.
// If Config.BasePath is set it is used directly; otherwise the base path is
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
//
// h := resolvemcp.NewSSEServer(handler)
// http.Handle("/api/mcp/", h)
func NewSSEServer(handler *Handler) http.Handler {
return handler.SSEServer()
} }

View File

@@ -67,7 +67,7 @@ func GetCursorFilter(
// 4. Process each sort column // 4. Process each sort column
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
for _, s := range sortItems { for _, s := range sortItems {
col := strings.TrimSpace(s.Column) col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" { if col == "" {
continue continue
} }

View File

@@ -64,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
// 4. Process each sort column // 4. Process each sort column
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
for _, s := range sortItems { for _, s := range sortItems {
col := strings.TrimSpace(s.Column) col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" { if col == "" {
continue continue
} }

View File

@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
} }
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error { func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist // Invalidate session via stored procedure
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{ return nil
"token": req.Token,
"user_id": req.UserID,
}).Error
} }
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {

View File

@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
} }
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil return nil
} }

View File

@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string var errMsg *string
var userID *int var userID *int
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id SELECT p_success, p_error, p_user_id
FROM resolvespec_oauth_getorcreateuser($1::jsonb) FROM %s($1::jsonb)
`, userJSON).Scan(&success, &errMsg, &userID) `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get or create user: %w", err) return 0, fmt.Errorf("failed to get or create user: %w", err)
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool var success bool
var errMsg *string var errMsg *string
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM resolvespec_oauth_createsession($1::jsonb) FROM %s($1::jsonb)
`, sessionJSON).Scan(&success, &errMsg) `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to create session: %w", err) return fmt.Errorf("failed to create session: %w", err)
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string var errMsg *string
var sessionData []byte var sessionData []byte
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getrefreshtoken($1) FROM %s($1)
`, refreshToken).Scan(&success, &errMsg, &sessionData) `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get session by refresh token: %w", err) return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool var updateSuccess bool
var updateErrMsg *string var updateErrMsg *string
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb) FROM %s($1::jsonb)
`, updateJSON).Scan(&updateSuccess, &updateErrMsg) `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update session: %w", err) return nil, fmt.Errorf("failed to update session: %w", err)
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string var userErrMsg *string
var userData []byte var userData []byte
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getuser($1) FROM %s($1)
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData) `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err) return nil, fmt.Errorf("failed to get user data: %w", err)

View File

@@ -11,12 +11,14 @@ import (
) )
// DatabasePasskeyProvider implements PasskeyProvider using database storage // DatabasePasskeyProvider implements PasskeyProvider using database storage
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct { type DatabasePasskeyProvider struct {
db *sql.DB db *sql.DB
rpID string // Relying Party ID (domain) rpID string // Relying Party ID (domain)
rpName string // Relying Party display name rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000) timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
} }
// DatabasePasskeyProviderOptions configures the passkey provider // DatabasePasskeyProviderOptions configures the passkey provider
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
RPOrigin string RPOrigin string
// Timeout is the timeout for operations in milliseconds (default: 60000) // Timeout is the timeout for operations in milliseconds (default: 60000)
Timeout int64 Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames
} }
// NewDatabasePasskeyProvider creates a new database-backed passkey provider // NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
opts.Timeout = 60000 // 60 seconds default opts.Timeout = 60000 // 60 seconds default
} }
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabasePasskeyProvider{ return &DatabasePasskeyProvider{
db: db, db: db,
rpID: opts.RPID, rpID: opts.RPID,
rpName: opts.RPName, rpName: opts.RPName,
rpOrigin: opts.RPOrigin, rpOrigin: opts.RPOrigin,
timeout: opts.Timeout, timeout: opts.Timeout,
sqlNames: sqlNames,
} }
} }
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialID sql.NullInt64 var credentialID sql.NullInt64
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err) return nil, fmt.Errorf("failed to store credential: %w", err)
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var userID sql.NullInt64 var userID sql.NullInt64
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialJSON sql.NullString var credentialJSON sql.NullString
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err) return 0, fmt.Errorf("failed to get credential: %w", err)
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var updateError sql.NullString var updateError sql.NullString
var cloneWarning sql.NullBool var cloneWarning sql.NullBool
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)` updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err) return 0, fmt.Errorf("failed to update counter: %w", err)
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete credential: %w", err) return fmt.Errorf("failed to delete credential: %w", err)
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to update credential name: %w", err) return fmt.Errorf("failed to update credential name: %w", err)

View File

@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// DatabaseAuthenticator provides session-based authentication with database storage // DatabaseAuthenticator provides session-based authentication with database storage
// All database operations go through stored procedures for security and consistency // All database operations go through stored procedures for security and consistency
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session, // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// resolvespec_session_update, resolvespec_refresh_token
// See database_schema.sql for procedure definitions // See database_schema.sql for procedure definitions
// Also supports multiple OAuth2 providers configured with WithOAuth2() // Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey() // Also supports passkey authentication configured with WithPasskey()
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
db *sql.DB db *sql.DB
cache *cache.Cache cache *cache.Cache
cacheTTL time.Duration cacheTTL time.Duration
sqlNames *SQLNames
// OAuth2 providers registry (multiple providers supported) // OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider oauth2Providers map[string]*OAuth2Provider
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
Cache *cache.Cache Cache *cache.Cache
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication // PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
PasskeyProvider PasskeyProvider PasskeyProvider PasskeyProvider
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames
} }
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
cacheInstance = cache.GetDefaultCache() cacheInstance = cache.GetDefaultCache()
} }
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{ return &DatabaseAuthenticator{
db: db, db: db,
cache: cacheInstance, cache: cacheInstance,
cacheTTL: opts.CacheTTL, cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider, passkeyProvider: opts.PasskeyProvider,
} }
} }
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
return nil, fmt.Errorf("failed to marshal login request: %w", err) return nil, fmt.Errorf("failed to marshal login request: %w", err)
} }
// Call resolvespec_login stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
return nil, fmt.Errorf("failed to marshal register request: %w", err) return nil, fmt.Errorf("failed to marshal register request: %w", err)
} }
// Call resolvespec_register stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("register query failed: %w", err) return nil, fmt.Errorf("register query failed: %w", err)
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
return fmt.Errorf("failed to marshal logout request: %w", err) return fmt.Errorf("failed to marshal logout request: %w", err)
} }
// Call resolvespec_logout stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
@@ -266,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON sql.NullString var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("session query failed: %w", err) return nil, fmt.Errorf("session query failed: %w", err)
@@ -338,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
return return
} }
// Call resolvespec_session_update stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var updatedUserJSON sql.NullString var updatedUserJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) _ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
} }
// RefreshToken implements Refreshable interface // RefreshToken implements Refreshable interface
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Call api_refresh_token stored procedure
// First, we need to get the current user context for the refresh token // First, we need to get the current user context for the refresh token
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON sql.NullString var userJSON sql.NullString
// Get current session to pass to refresh // Get current session to pass to refresh
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err) return nil, fmt.Errorf("refresh token query failed: %w", err)
@@ -368,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
return nil, fmt.Errorf("invalid refresh token") return nil, fmt.Errorf("invalid refresh token")
} }
// Call resolvespec_refresh_token to generate new token
var newSuccess bool var newSuccess bool
var newErrorMsg sql.NullString var newErrorMsg sql.NullString
var newUserJSON sql.NullString var newUserJSON sql.NullString
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)` refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err) return nil, fmt.Errorf("refresh token generation failed: %w", err)
@@ -401,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
// JWTAuthenticator provides JWT token-based authentication // JWTAuthenticator provides JWT token-based authentication
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported // NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
type JWTAuthenticator struct { type JWTAuthenticator struct {
secretKey []byte secretKey []byte
db *sql.DB db *sql.DB
sqlNames *SQLNames
} }
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator { func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
return &JWTAuthenticator{ return &JWTAuthenticator{
secretKey: []byte(secretKey), secretKey: []byte(secretKey),
db: db, db: db,
sqlNames: resolveSQLNames(names...),
} }
} }
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Call resolvespec_jwt_login stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON []byte var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
@@ -471,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
} }
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Call resolvespec_jwt_logout stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
@@ -511,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// DatabaseColumnSecurityProvider loads column security from database // DatabaseColumnSecurityProvider loads column security from database
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedure: resolvespec_column_security // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct { type DatabaseColumnSecurityProvider struct {
db *sql.DB db *sql.DB
sqlNames *SQLNames
} }
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider { func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db} return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
} }
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity var rules []ColumnSecurity
// Call resolvespec_column_security stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var rulesJSON []byte var rulesJSON []byte
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)` query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err) return nil, fmt.Errorf("failed to load column security: %w", err)
@@ -576,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// DatabaseRowSecurityProvider loads row security from database // DatabaseRowSecurityProvider loads row security from database
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedure: resolvespec_row_security // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct { type DatabaseRowSecurityProvider struct {
db *sql.DB db *sql.DB
sqlNames *SQLNames
} }
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider { func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db} return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
} }
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string var template string
var hasBlock bool var hasBlock bool
// Call resolvespec_row_security stored procedure query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
if err != nil { if err != nil {
@@ -758,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
return nil, fmt.Errorf("passkey authentication failed: %w", err) return nil, fmt.Errorf("passkey authentication failed: %w", err)
} }
// Get user data from database // Build request JSON for passkey login stored procedure
var username, email, roles string reqData := map[string]any{
var userLevel int "user_id": userID,
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
} }
// Generate session token
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Extract IP and user agent from claims
ipAddress := ""
userAgent := ""
if req.Claims != nil { if req.Claims != nil {
if ip, ok := req.Claims["ip_address"].(string); ok { if ip, ok := req.Claims["ip_address"].(string); ok {
ipAddress = ip reqData["ip_address"] = ip
} }
if ua, ok := req.Claims["user_agent"].(string); ok { if ua, ok := req.Claims["user_agent"].(string); ok {
userAgent = ua reqData["user_agent"] = ua
} }
} }
// Create session reqJSON, err := json.Marshal(reqData)
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES ($1, $2, $3, $4, $5, now())`
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err) return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
} }
// Update last login var success bool
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1` var errorMsg sql.NullString
_, _ = a.db.ExecContext(ctx, updateQuery, userID) var dataJSON sql.NullString
// Return login response query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
return &LoginResponse{ err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
Token: sessionToken, if err != nil {
User: &UserContext{ return nil, fmt.Errorf("passkey login query failed: %w", err)
UserID: userID, }
UserName: username,
Email: email, if !success {
UserLevel: userLevel, if errorMsg.Valid {
SessionID: sessionToken, return nil, fmt.Errorf("%s", errorMsg.String)
Roles: parseRoles(roles), }
}, return nil, fmt.Errorf("passkey login failed")
ExpiresIn: int64(24 * time.Hour.Seconds()), }
}, nil
var response LoginResponse
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
}
return &response, nil
} }
// GetPasskeyCredentials returns all passkey credentials for a user // GetPasskeyCredentials returns all passkey credentials for a user

222
pkg/security/sql_names.go Normal file
View File

@@ -0,0 +1,222 @@
package security
import (
"fmt"
"reflect"
"regexp"
)
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// SQLNames defines all configurable SQL stored procedure and table names
// used by the security package. Override individual fields to remap
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
// and MergeSQLNames() to apply partial overrides.
type SQLNames struct {
// Auth procedures (DatabaseAuthenticator)
Login string // default: "resolvespec_login"
Register string // default: "resolvespec_register"
Logout string // default: "resolvespec_logout"
Session string // default: "resolvespec_session"
SessionUpdate string // default: "resolvespec_session_update"
RefreshToken string // default: "resolvespec_refresh_token"
// JWT procedures (JWTAuthenticator)
JWTLogin string // default: "resolvespec_jwt_login"
JWTLogout string // default: "resolvespec_jwt_logout"
// Security policy procedures
ColumnSecurity string // default: "resolvespec_column_security"
RowSecurity string // default: "resolvespec_row_security"
// TOTP procedures (DatabaseTwoFactorProvider)
TOTPEnable string // default: "resolvespec_totp_enable"
TOTPDisable string // default: "resolvespec_totp_disable"
TOTPGetStatus string // default: "resolvespec_totp_get_status"
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
// Passkey procedures (DatabasePasskeyProvider)
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
PasskeyLogin string // default: "resolvespec_passkey_login"
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser"
}
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
func DefaultSQLNames() *SQLNames {
return &SQLNames{
Login: "resolvespec_login",
Register: "resolvespec_register",
Logout: "resolvespec_logout",
Session: "resolvespec_session",
SessionUpdate: "resolvespec_session_update",
RefreshToken: "resolvespec_refresh_token",
JWTLogin: "resolvespec_jwt_login",
JWTLogout: "resolvespec_jwt_logout",
ColumnSecurity: "resolvespec_column_security",
RowSecurity: "resolvespec_row_security",
TOTPEnable: "resolvespec_totp_enable",
TOTPDisable: "resolvespec_totp_disable",
TOTPGetStatus: "resolvespec_totp_get_status",
TOTPGetSecret: "resolvespec_totp_get_secret",
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
PasskeyGetCredential: "resolvespec_passkey_get_credential",
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
PasskeyUpdateName: "resolvespec_passkey_update_name",
PasskeyLogin: "resolvespec_passkey_login",
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
OAuthCreateSession: "resolvespec_oauth_createsession",
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser",
}
}
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.Login != "" {
merged.Login = override.Login
}
if override.Register != "" {
merged.Register = override.Register
}
if override.Logout != "" {
merged.Logout = override.Logout
}
if override.Session != "" {
merged.Session = override.Session
}
if override.SessionUpdate != "" {
merged.SessionUpdate = override.SessionUpdate
}
if override.RefreshToken != "" {
merged.RefreshToken = override.RefreshToken
}
if override.JWTLogin != "" {
merged.JWTLogin = override.JWTLogin
}
if override.JWTLogout != "" {
merged.JWTLogout = override.JWTLogout
}
if override.ColumnSecurity != "" {
merged.ColumnSecurity = override.ColumnSecurity
}
if override.RowSecurity != "" {
merged.RowSecurity = override.RowSecurity
}
if override.TOTPEnable != "" {
merged.TOTPEnable = override.TOTPEnable
}
if override.TOTPDisable != "" {
merged.TOTPDisable = override.TOTPDisable
}
if override.TOTPGetStatus != "" {
merged.TOTPGetStatus = override.TOTPGetStatus
}
if override.TOTPGetSecret != "" {
merged.TOTPGetSecret = override.TOTPGetSecret
}
if override.TOTPRegenerateBackup != "" {
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
}
if override.TOTPValidateBackupCode != "" {
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
}
if override.PasskeyStoreCredential != "" {
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
}
if override.PasskeyGetCredsByUsername != "" {
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
}
if override.PasskeyGetCredential != "" {
merged.PasskeyGetCredential = override.PasskeyGetCredential
}
if override.PasskeyUpdateCounter != "" {
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
}
if override.PasskeyGetUserCredentials != "" {
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
}
if override.PasskeyDeleteCredential != "" {
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
}
if override.PasskeyUpdateName != "" {
merged.PasskeyUpdateName = override.PasskeyUpdateName
}
if override.PasskeyLogin != "" {
merged.PasskeyLogin = override.PasskeyLogin
}
if override.OAuthGetOrCreateUser != "" {
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
}
if override.OAuthCreateSession != "" {
merged.OAuthCreateSession = override.OAuthCreateSession
}
if override.OAuthGetRefreshToken != "" {
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
}
if override.OAuthUpdateRefreshToken != "" {
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
}
if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser
}
return &merged
}
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
// Returns an error if any field contains invalid characters.
func ValidateSQLNames(names *SQLNames) error {
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
val := field.String()
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
}
}
return nil
}
// resolveSQLNames merges an optional override with defaults.
// Used by constructors that accept variadic *SQLNames.
func resolveSQLNames(override ...*SQLNames) *SQLNames {
if len(override) > 0 && override[0] != nil {
return MergeSQLNames(DefaultSQLNames(), override[0])
}
return DefaultSQLNames()
}

View File

@@ -0,0 +1,145 @@
package security
import (
"reflect"
"testing"
)
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
names := DefaultSQLNames()
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
if field.String() == "" {
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
}
}
}
func TestMergeSQLNames_PartialOverride(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{
Login: "custom_login",
TOTPEnable: "custom_totp_enable",
PasskeyLogin: "custom_passkey_login",
}
merged := MergeSQLNames(base, override)
if merged.Login != "custom_login" {
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
}
if merged.TOTPEnable != "custom_totp_enable" {
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
}
if merged.PasskeyLogin != "custom_passkey_login" {
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
}
// Non-overridden fields should retain defaults
if merged.Logout != "resolvespec_logout" {
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
}
if merged.Session != "resolvespec_session" {
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
}
}
func TestMergeSQLNames_NilOverride(t *testing.T) {
base := DefaultSQLNames()
merged := MergeSQLNames(base, nil)
// Should be a copy, not the same pointer
if merged == base {
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
}
// All values should match
v1 := reflect.ValueOf(base).Elem()
v2 := reflect.ValueOf(merged).Elem()
typ := v1.Type()
for i := 0; i < v1.NumField(); i++ {
f1 := v1.Field(i)
f2 := v2.Field(i)
if f1.Kind() != reflect.String {
continue
}
if f1.String() != f2.String() {
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
}
}
}
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
base := DefaultSQLNames()
originalLogin := base.Login
override := &SQLNames{Login: "custom_login"}
_ = MergeSQLNames(base, override)
if base.Login != originalLogin {
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
}
}
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{}
v := reflect.ValueOf(override).Elem()
for i := 0; i < v.NumField(); i++ {
if v.Field(i).Kind() == reflect.String {
v.Field(i).SetString("custom_sentinel")
}
}
merged := MergeSQLNames(base, override)
mv := reflect.ValueOf(merged).Elem()
typ := mv.Type()
for i := 0; i < mv.NumField(); i++ {
if mv.Field(i).Kind() != reflect.String {
continue
}
if mv.Field(i).String() != "custom_sentinel" {
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
}
}
}
func TestValidateSQLNames_Valid(t *testing.T) {
names := DefaultSQLNames()
if err := ValidateSQLNames(names); err != nil {
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
}
}
func TestValidateSQLNames_Invalid(t *testing.T) {
names := DefaultSQLNames()
names.Login = "resolvespec_login; DROP TABLE users; --"
err := ValidateSQLNames(names)
if err == nil {
t.Error("ValidateSQLNames should reject names with invalid characters")
}
}
func TestResolveSQLNames_NoOverride(t *testing.T) {
names := resolveSQLNames()
if names.Login != "resolvespec_login" {
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
}
}
func TestResolveSQLNames_WithOverride(t *testing.T) {
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
if names.Login != "custom_login" {
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
}
if names.Logout != "resolvespec_logout" {
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
}
}

View File

@@ -9,23 +9,23 @@ import (
) )
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures // DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable, // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
// See totp_database_schema.sql for procedure definitions // See totp_database_schema.sql for procedure definitions
type DatabaseTwoFactorProvider struct { type DatabaseTwoFactorProvider struct {
db *sql.DB db *sql.DB
totpGen *TOTPGenerator totpGen *TOTPGenerator
sqlNames *SQLNames
} }
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider // NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider { func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
if config == nil { if config == nil {
config = DefaultTwoFactorConfig() config = DefaultTwoFactorConfig()
} }
return &DatabaseTwoFactorProvider{ return &DatabaseTwoFactorProvider{
db: db, db: db,
totpGen: NewTOTPGenerator(config), totpGen: NewTOTPGenerator(config),
sqlNames: resolveSQLNames(names...),
} }
} }
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg) err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("enable 2FA query failed: %w", err) return fmt.Errorf("enable 2FA query failed: %w", err)
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("disable 2FA query failed: %w", err) return fmt.Errorf("disable 2FA query failed: %w", err)
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
var errorMsg sql.NullString var errorMsg sql.NullString
var enabled bool var enabled bool
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
if err != nil { if err != nil {
return false, fmt.Errorf("get 2FA status query failed: %w", err) return false, fmt.Errorf("get 2FA status query failed: %w", err)
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var errorMsg sql.NullString var errorMsg sql.NullString
var secret sql.NullString var secret sql.NullString
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
if err != nil { if err != nil {
return "", fmt.Errorf("get 2FA secret query failed: %w", err) return "", fmt.Errorf("get 2FA secret query failed: %w", err)
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg) err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err) return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
var errorMsg sql.NullString var errorMsg sql.NullString
var valid bool var valid bool
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid) err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
if err != nil { if err != nil {
return false, fmt.Errorf("validate backup code query failed: %w", err) return false, fmt.Errorf("validate backup code query failed: %w", err)