Compare commits

...

5 Commits

Author SHA1 Message Date
Hein
79a3912f93 fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
2026-04-09 09:19:28 +02:00
Hein
a9bf08f58b feat(security): implement keystore for user authentication keys
* Add ConfigKeyStore for in-memory key management
* Introduce DatabaseKeyStore for PostgreSQL-backed key storage
* Create KeyStoreAuthenticator for API key validation
* Define SQL procedures for key management in PostgreSQL
* Document keystore functionality and usage in KEYSTORE.md
2026-04-07 17:09:17 +02:00
Hein
405a04a192 feat(security): integrate security hooks for access control
Some checks failed
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m6s
Tests / Unit Tests (push) Successful in -30m22s
Tests / Integration Tests (push) Failing after -30m41s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m3s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m36s
Build , Vet Test, and Lint / Build (push) Successful in -29m58s
* Add security hooks for per-entity operation rules and row/column-level security.
* Implement annotation tool for storing and retrieving freeform annotations.
* Enhance handler to support model registration with access rules.
2026-04-07 15:53:12 +02:00
Hein
c1b16d363a feat(db): add DB method to sqlConnection and mongoConnection
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m22s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m59s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m11s
Build , Vet Test, and Lint / Build (push) Successful in -30m12s
Tests / Unit Tests (push) Successful in -30m49s
Tests / Integration Tests (push) Failing after -30m59s
2026-04-01 15:34:09 +02:00
Hein
568df8c6d6 feat(security): add configurable SQL procedure names
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m29s
Build , Vet Test, and Lint / Build (push) Successful in -30m5s
Build , Vet Test, and Lint / Lint Code (push) Failing after -28m58s
Tests / Integration Tests (push) Failing after -30m26s
Tests / Unit Tests (push) Successful in -28m7s
* Introduce SQLNames struct to define stored procedure names.
* Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls.
* Remove hardcoded procedure names for better flexibility and customization.
* Implement validation for SQL names to ensure they are valid identifiers.
* Add tests for SQLNames functionality and merging behavior.
2026-03-31 14:25:59 +02:00
24 changed files with 2203 additions and 177 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
// This demonstrates how the abstraction works with different ORMs // This demonstrates how the abstraction works with different ORMs
type BunAdapter struct { type BunAdapter struct {
db *bun.DB db *bun.DB
dbMu sync.RWMutex
dbFactory func() (*bun.DB, error)
driverName string driverName string
} }
@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
return adapter return adapter
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
b.dbFactory = factory
return b
}
func (b *BunAdapter) getDB() *bun.DB {
b.dbMu.RLock()
defer b.dbMu.RUnlock()
return b.db
}
func (b *BunAdapter) reconnectDB() error {
if b.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := b.dbFactory()
if err != nil {
return err
}
b.dbMu.Lock()
b.db = newDB
b.dbMu.Unlock()
return nil
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads // EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing // This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() { func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{}) b.getDB().AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged") logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
} }
@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
func (b *BunAdapter) NewSelect() common.SelectQuery { func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{ return &BunSelectQuery{
query: b.db.NewSelect(), query: b.getDB().NewSelect(),
db: b.db, db: b.db,
driverName: b.driverName, driverName: b.driverName,
} }
} }
func (b *BunAdapter) NewInsert() common.InsertQuery { func (b *BunAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.db.NewInsert()} return &BunInsertQuery{query: b.getDB().NewInsert()}
} }
func (b *BunAdapter) NewUpdate() common.UpdateQuery { func (b *BunAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.db.NewUpdate()} return &BunUpdateQuery{query: b.getDB().NewUpdate()}
} }
func (b *BunAdapter) NewDelete() common.DeleteQuery { func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()} return &BunDeleteQuery{query: b.getDB().NewDelete()}
} }
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
err = logger.HandlePanic("BunAdapter.Exec", r) err = logger.HandlePanic("BunAdapter.Exec", r)
} }
}() }()
result, err := b.db.ExecContext(ctx, query, args...) var result sql.Result
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = run()
}
}
return &BunResult{result: result}, err return &BunResult{result: result}, err
} }
@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("BunAdapter.Query", r) err = logger.HandlePanic("BunAdapter.Query", r)
} }
}() }()
return b.db.NewRaw(query, args...).Scan(ctx, dest) err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
}
}
return err
} }
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) { func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{}) tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
err = logger.HandlePanic("BunAdapter.RunInTransaction", r) err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
} }
}() }()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction // Create adapter with transaction
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName} adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
return fn(adapter) return fn(adapter)
@@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
} }
func (b *BunAdapter) GetUnderlyingDB() interface{} { func (b *BunAdapter) GetUnderlyingDB() interface{} {
return b.db return b.getDB()
} }
func (b *BunAdapter) DriverName() string { func (b *BunAdapter) DriverName() string {

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -17,6 +18,8 @@ import (
// This provides a lightweight PostgreSQL adapter without ORM overhead // This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct { type PgSQLAdapter struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string driverName string
} }
@@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
return &PgSQLAdapter{db: db, driverName: name} return &PgSQLAdapter{db: db, driverName: name}
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
p.dbFactory = factory
return p
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *PgSQLAdapter) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// EnableQueryDebug enables query debugging for development // EnableQueryDebug enables query debugging for development
func (p *PgSQLAdapter) EnableQueryDebug() { func (p *PgSQLAdapter) EnableQueryDebug() {
logger.Info("PgSQL query debug mode - logging enabled via logger") logger.Info("PgSQL query debug mode - logging enabled via logger")
@@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery { func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{ return &PgSQLSelectQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
columns: []string{"*"}, columns: []string{"*"},
args: make([]interface{}, 0), args: make([]interface{}, 0),
@@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
func (p *PgSQLAdapter) NewInsert() common.InsertQuery { func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{ return &PgSQLInsertQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
values: make(map[string]interface{}), values: make(map[string]interface{}),
} }
@@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{ return &PgSQLUpdateQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
sets: make(map[string]interface{}), sets: make(map[string]interface{}),
args: make([]interface{}, 0), args: make([]interface{}, 0),
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery { func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{ return &PgSQLDeleteQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
args: make([]interface{}, 0), args: make([]interface{}, 0),
whereClauses: make([]string, 0), whereClauses: make([]string, 0),
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
} }
}() }()
logger.Debug("PgSQL Exec: %s [args: %v]", query, args) logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
result, err := p.db.ExecContext(ctx, query, args...) var result sql.Result
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil { if err != nil {
logger.Error("PgSQL Exec failed: %v", err) logger.Error("PgSQL Exec failed: %v", err)
return nil, err return nil, err
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
} }
}() }()
logger.Debug("PgSQL Query: %s [args: %v]", query, args) logger.Debug("PgSQL Query: %s [args: %v]", query, args)
rows, err := p.db.QueryContext(ctx, query, args...) var rows *sql.Rows
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil { if err != nil {
logger.Error("PgSQL Query failed: %v", err) logger.Error("PgSQL Query failed: %v", err)
return err return err
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
} }
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) { func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := p.db.BeginTx(ctx, nil) tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
} }
}() }()
tx, err := p.db.BeginTx(ctx, nil) tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -26,6 +26,7 @@ type Connection interface {
Bun() (*bun.DB, error) Bun() (*bun.DB, error)
GORM() (*gorm.DB, error) GORM() (*gorm.DB, error)
Native() (*sql.DB, error) Native() (*sql.DB, error)
DB() (*sql.DB, error)
// Common Database interface (for SQL databases) // Common Database interface (for SQL databases)
Database() (common.Database, error) Database() (common.Database, error)
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
return c.nativeDB, nil return c.nativeDB, nil
} }
// DB returns the underlying *sql.DB connection
func (c *sqlConnection) DB() (*sql.DB, error) {
return c.Native()
}
// Bun returns a Bun ORM instance wrapping the native connection // Bun returns a Bun ORM instance wrapping the native connection
func (c *sqlConnection) Bun() (*bun.DB, error) { func (c *sqlConnection) Bun() (*bun.DB, error) {
if c == nil { if c == nil {
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
return nil, ErrNotSQLDatabase return nil, ErrNotSQLDatabase
} }
// DB returns an error for MongoDB connections
func (c *mongoConnection) DB() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// Database returns an error for MongoDB connections // Database returns an error for MongoDB connections
func (c *mongoConnection) Database() (common.Database, error) { func (c *mongoConnection) Database() (common.Database, error) {
return nil, ErrNotSQLDatabase return nil, ErrNotSQLDatabase

View File

@@ -4,11 +4,17 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strings"
"time" "time"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
) )
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// Common errors // Common errors
var ( var (
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database // ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"sync"
"time" "time"
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver _ "github.com/glebarez/sqlite" // Pure Go SQLite driver
@@ -14,8 +15,10 @@ import (
// SQLiteProvider implements Provider for SQLite databases // SQLiteProvider implements Provider for SQLite databases
type SQLiteProvider struct { type SQLiteProvider struct {
db *sql.DB db *sql.DB
config ConnectionConfig dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
config ConnectionConfig
} }
// NewSQLiteProvider creates a new SQLite provider // NewSQLiteProvider creates a new SQLite provider
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
// Execute a simple query to verify the database is accessible // Execute a simple query to verify the database is accessible
var result int var result int
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result) run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
err := run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil { if err != nil {
return fmt.Errorf("health check failed: %w", err) return fmt.Errorf("health check failed: %w", err)
} }
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
return nil return nil
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
p.dbFactory = factory
return p
}
func (p *SQLiteProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *SQLiteProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// GetNative returns the native *sql.DB connection // GetNative returns the native *sql.DB connection
func (p *SQLiteProvider) GetNative() (*sql.DB, error) { func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
if p.db == nil { if p.db == nil {

View File

@@ -119,6 +119,83 @@ Add middleware before the MCP routes. The handler itself has no auth layer.
--- ---
## Security
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
### Wiring security hooks
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
securityList := security.NewSecurityList(mySecurityProvider)
resolvemcp.RegisterSecurityHooks(handler, securityList)
```
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
| Hook | Effect |
|---|---|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
### Per-entity operation rules
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
```go
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
// Read-only entity
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
CanRead: true,
CanCreate: false,
CanUpdate: false,
CanDelete: false,
})
// Public read, authenticated write
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
CanPublicRead: true,
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
To update rules for an already-registered model:
```go
handler.SetModelRules("public", "users", modelregistry.ModelRules{
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
### ModelRules fields
| Field | Default | Description |
|---|---|---|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
| `CanRead` | `true` | Allow authenticated reads |
| `CanCreate` | `true` | Allow authenticated creates |
| `CanUpdate` | `true` | Allow authenticated updates |
| `CanDelete` | `true` | Allow authenticated deletes |
| `SecurityDisabled` | `false` | Skip all security checks for this model |
---
## MCP Tools ## MCP Tools
### Tool Naming ### Tool Naming
@@ -204,6 +281,35 @@ Delete a record by primary key. **Irreversible.**
{ "success": true, "data": { ...deleted record... } } { "success": true, "data": { ...deleted record... } }
``` ```
### Annotation Tool — `resolvespec_annotate`
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
| Argument | Type | Description |
|---|---|---|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
```json
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "set" }
```
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
```json
{ "tool_name": "read_public_users" }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
```
---
### Resource — `{schema}.{entity}` ### Resource — `{schema}.{entity}`
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`. Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.

View File

@@ -0,0 +1,107 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
)
const annotationToolName = "resolvespec_annotate"
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
// The tool lets models/entities store and retrieve freeform annotation records
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
func registerAnnotationTool(h *Handler) {
tool := mcp.NewTool(annotationToolName,
mcp.WithDescription(
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
"To set annotations: provide both 'tool_name' and 'annotations'. "+
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
"To get annotations: provide only 'tool_name'. "+
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
"a model/entity name (e.g. 'public.users'), or any other key.",
),
mcp.WithString("tool_name",
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
mcp.Required(),
),
mcp.WithObject("annotations",
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
toolName, ok := args["tool_name"].(string)
if !ok || toolName == "" {
return mcp.NewToolResultError("missing required argument: tool_name"), nil
}
annotations, hasAnnotations := args["annotations"]
if hasAnnotations && annotations != nil {
return executeSetAnnotation(ctx, h, toolName, annotations)
}
return executeGetAnnotation(ctx, h, toolName)
})
}
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
jsonBytes, err := json.Marshal(annotations)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
}
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "set",
})
}
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
var rows []map[string]interface{}
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
}
var annotations interface{}
if len(rows) > 0 {
// The procedure returns a single value; extract the first column of the first row.
for _, v := range rows[0] {
annotations = v
break
}
}
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
switch v := annotations.(type) {
case []byte:
var decoded interface{}
if json.Unmarshal(v, &decoded) == nil {
annotations = decoded
}
case string:
var decoded interface{}
if json.Unmarshal([]byte(v), &decoded) == nil {
annotations = decoded
}
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "get",
"annotations": annotations,
})
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
@@ -30,7 +31,7 @@ type Handler struct {
// NewHandler creates a Handler with the given database, model registry, and config. // NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler { func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
return &Handler{ h := &Handler{
db: db, db: db,
registry: registry, registry: registry,
hooks: NewHookRegistry(), hooks: NewHookRegistry(),
@@ -39,6 +40,8 @@ func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *
name: "resolvemcp", name: "resolvemcp",
version: "1.0.0", version: "1.0.0",
} }
registerAnnotationTool(h)
return h
} }
// Hooks returns the hook registry. // Hooks returns the hook registry.
@@ -71,7 +74,7 @@ func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer( return server.NewSSEServer(
h.mcpServer, h.mcpServer,
server.WithBaseURL(baseURL), server.WithBaseURL(baseURL),
server.WithBasePath(basePath), server.WithStaticBasePath(basePath),
) )
} }
@@ -123,6 +126,32 @@ func (h *Handler) RegisterModel(schema, entity string, model interface{}) error
return nil return nil
} }
// RegisterModelWithRules registers a model and sets per-entity operation rules
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
fullName := buildModelName(schema, entity)
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// SetModelRules updates the operation rules for an already-registered model.
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
return reg.SetModelRules(buildModelName(schema, entity), rules)
}
// buildModelName builds the registry key for a model (same format as resolvespec). // buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string { func buildModelName(schema, entity string) string {
if schema == "" { if schema == "" {
@@ -666,7 +695,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
return query.Where("("+strings.Join(conditions, " OR ")+")", args...) return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
} }
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) { func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
switch filter.Operator { switch filter.Operator {
case "eq", "=": case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value} return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
@@ -696,7 +725,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in
} }
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for _, preload := range preloads { for i := range preloads {
preload := &preloads[i]
if preload.Relation == "" { if preload.Relation == "" {
continue continue
} }

View File

@@ -0,0 +1,115 @@
package resolvemcp
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks wires the security package's access-control layer into the
// resolvemcp handler. Call it once after creating the handler, before registering models.
//
// The following controls are applied:
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
// stored via RegisterModelWithRules / SetModelRules.
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
// - Column-level security: sensitive columns masked/hidden in read results.
// - Audit logging after each read.
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// BeforeHandle: enforce model-level operation rules (auth check).
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
})
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (2nd): audit log.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.LogDataAccess(newSecurityContext(hookCtx))
})
// BeforeUpdate: enforce CanUpdate rule.
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
})
// BeforeDelete: enforce CanDelete rule.
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
})
logger.Info("Security hooks registered for resolvemcp handler")
}
// --------------------------------------------------------------------------
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
// --------------------------------------------------------------------------
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

153
pkg/security/KEYSTORE.md Normal file
View File

@@ -0,0 +1,153 @@
# Keystore
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
## Key types
| Constant | Value | Use case |
|---|---|---|
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
## Storage backends
### ConfigKeyStore
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
```go
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
store := security.NewConfigKeyStore([]security.UserKey{
{
UserID: 1,
KeyType: security.KeyTypeGenericAPI,
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
Name: "CI deploy",
Scopes: []string{"deploy"},
IsActive: true,
},
})
```
### DatabaseKeyStore
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
```go
db, _ := sql.Open("postgres", dsn)
store := security.NewDatabaseKeyStore(db)
// With options
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
CacheTTL: 5 * time.Minute,
SQLNames: &security.KeyStoreSQLNames{
ValidateKey: "myapp_keystore_validate", // override one procedure name
},
})
```
## Managing keys
```go
ctx := context.Background()
// Create — raw key returned once; store it securely
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
UserID: 42,
KeyType: security.KeyTypeGenericAPI,
Name: "mobile app",
Scopes: []string{"read", "write"},
})
fmt.Println(resp.RawKey) // only shown here; hashed internally
// List
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
// Revoke
err = store.DeleteKey(ctx, 42, resp.Key.ID)
// Validate (used by authenticators internally)
key, err := store.ValidateKey(ctx, rawKey, "")
```
## HTTP authentication
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
Keys are extracted from the request in this order:
1. `Authorization: Bearer <key>`
2. `Authorization: ApiKey <key>`
3. `X-API-Key: <key>`
```go
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
// Restrict to a specific type:
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
```
Plug it into a handler:
```go
handler := resolvespec.NewHandler(db, registry,
resolvespec.WithAuthenticator(auth),
)
```
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
On successful validation the request context receives a `UserContext` where:
- `UserID` — from the key
- `Roles` — the key's `Scopes`
- `Claims["key_type"]` — key type string
- `Claims["key_name"]` — key name
## Database setup
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
```sql
\i pkg/security/keystore_schema.sql
```
This creates:
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
- `resolvespec_keystore_create_key(p_request jsonb)`
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
### Custom procedure names
```go
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
SQLNames: &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
CreateKey: "myschema_create_key",
DeleteKey: "myschema_delete_key",
ValidateKey: "myschema_validate_key",
},
})
// Validate names at startup
names := &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
// ...
}
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
log.Fatal(err)
}
```
## Security notes
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.

View File

@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
} }
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error { func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist // Invalidate session via stored procedure
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{ return nil
"token": req.Token,
"user_id": req.UserID,
}).Error
} }
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {

View File

@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
} }
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil return nil
} }

81
pkg/security/keystore.go Normal file
View File

@@ -0,0 +1,81 @@
package security
import (
"context"
"crypto/sha256"
"encoding/hex"
"time"
)
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
// Used by all keystore implementations to hash raw keys before storage or lookup.
func hashSHA256Hex(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
// KeyType identifies the category of an auth key.
type KeyType string
const (
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
KeyTypeJWTSecret KeyType = "jwt_secret"
// KeyTypeHeaderAPI is a static API key sent via a request header.
KeyTypeHeaderAPI KeyType = "header_api"
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
KeyTypeOAuth2 KeyType = "oauth2"
// KeyTypeGenericAPI is a generic application API key.
KeyTypeGenericAPI KeyType = "api"
)
// UserKey represents a single named auth key belonging to a user.
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
type UserKey struct {
ID int64 `json:"id"`
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
IsActive bool `json:"is_active"`
}
// CreateKeyRequest specifies the parameters for a new key.
type CreateKeyRequest struct {
UserID int
KeyType KeyType
Name string
Scopes []string
Meta map[string]any
ExpiresAt *time.Time
}
// CreateKeyResponse is returned exactly once when a key is created.
// The caller is responsible for persisting RawKey; it is not stored anywhere.
type CreateKeyResponse struct {
Key UserKey
RawKey string // crypto/rand 32 bytes, base64url-encoded
}
// KeyStore manages per-user auth keys with pluggable storage backends.
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
type KeyStore interface {
// CreateKey generates a new key, stores its hash, and returns the raw key once.
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
// GetUserKeys returns all active, non-expired keys for a user.
// Pass an empty KeyType to return all types.
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
// DeleteKey soft-deletes a key by ID after verifying ownership.
DeleteKey(ctx context.Context, userID int, keyID int64) error
// ValidateKey checks a raw key, returns the matching UserKey on success.
// The implementation hashes the raw key before any lookup.
// Pass an empty KeyType to accept any type.
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
}

View File

@@ -0,0 +1,97 @@
package security
import (
"context"
"fmt"
"net/http"
"strings"
)
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
// is managed directly through the KeyStore.
//
// Key extraction order:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
type KeyStoreAuthenticator struct {
keyStore KeyStore
keyType KeyType // empty = accept any type
}
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
// Pass an empty keyType to accept keys of any type.
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
}
// Login is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
return nil, fmt.Errorf("keystore authenticator does not support login")
}
// Logout is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
return nil
}
// Authenticate extracts an API key from the request and validates it against the KeyStore.
// Returns a UserContext built from the matching UserKey on success.
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
rawKey := extractAPIKey(r)
if rawKey == "" {
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
}
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return userKeyToUserContext(userKey), nil
}
// extractAPIKey extracts a raw key from the request using the following precedence:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
func extractAPIKey(r *http.Request) string {
if auth := r.Header.Get("Authorization"); auth != "" {
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
return strings.TrimSpace(after)
}
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
return strings.TrimSpace(after)
}
}
return strings.TrimSpace(r.Header.Get("X-API-Key"))
}
// userKeyToUserContext converts a UserKey into a UserContext.
// Scopes are mapped to Roles. Key type and name are stored in Claims.
func userKeyToUserContext(k *UserKey) *UserContext {
claims := map[string]any{
"key_type": string(k.KeyType),
"key_name": k.Name,
}
meta := k.Meta
if meta == nil {
meta = map[string]any{}
}
roles := k.Scopes
if roles == nil {
roles = []string{}
}
return &UserContext{
UserID: k.UserID,
SessionID: fmt.Sprintf("key:%d", k.ID),
Roles: roles,
Claims: claims,
Meta: meta,
}
}

View File

@@ -0,0 +1,149 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"time"
)
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
//
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
type ConfigKeyStore struct {
mu sync.RWMutex
keys []UserKey
next int64 // monotonic ID counter for runtime-created keys (atomic)
}
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
// Pass nil or an empty slice to start with no pre-loaded keys.
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
var maxID int64
copied := make([]UserKey, len(keys))
copy(copied, keys)
for i := range copied {
if copied[i].CreatedAt.IsZero() {
copied[i].IsActive = true
copied[i].CreatedAt = time.Now()
}
if copied[i].ID > maxID {
maxID = copied[i].ID
}
}
return &ConfigKeyStore{keys: copied, next: maxID}
}
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
id := atomic.AddInt64(&s.next, 1)
key := UserKey{
ID: id,
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
CreatedAt: time.Now(),
IsActive: true,
}
s.mu.Lock()
s.keys = append(s.keys, key)
s.mu.Unlock()
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
now := time.Now()
s.mu.RLock()
defer s.mu.RUnlock()
var result []UserKey
for i := range s.keys {
k := &s.keys[i]
if k.UserID != userID || !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
result = append(result, *k)
}
return result, nil
}
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
if s.keys[i].ID == keyID {
if s.keys[i].UserID != userID {
return fmt.Errorf("key not found or permission denied")
}
s.keys[i].IsActive = false
return nil
}
}
return fmt.Errorf("key not found")
}
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
// Uses constant-time comparison to prevent timing side-channels.
// Pass an empty KeyType to accept any type.
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
hashBytes, _ := hex.DecodeString(hash)
now := time.Now()
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
k := &s.keys[i]
if !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
stored, _ := hex.DecodeString(k.KeyHash)
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
continue
}
k.LastUsedAt = &now
result := *k
return &result, nil
}
return nil, fmt.Errorf("invalid or expired key")
}

View File

@@ -0,0 +1,256 @@
package security
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
type DatabaseKeyStoreOptions struct {
// Cache is an optional cache instance. If nil, uses the default cache.
Cache *cache.Cache
// CacheTTL is the duration to cache ValidateKey results.
// Default: 2 minutes.
CacheTTL time.Duration
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
SQLNames *KeyStoreSQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
// All DB operations go through configurable procedure names; the raw key is
// never passed to the database.
//
// See keystore_schema.sql for the required table and procedure definitions.
//
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
}
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
o := DatabaseKeyStoreOptions{}
if len(opts) > 0 {
o = opts[0]
}
if o.CacheTTL == 0 {
o.CacheTTL = 2 * time.Minute
}
c := o.Cache
if c == nil {
c = cache.GetDefaultCache()
}
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
return &DatabaseKeyStore{
db: db,
dbFactory: o.DBFactory,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
}
}
func (ks *DatabaseKeyStore) getDB() *sql.DB {
ks.dbMu.RLock()
defer ks.dbMu.RUnlock()
return ks.db
}
func (ks *DatabaseKeyStore) reconnectDB() error {
if ks.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := ks.dbFactory()
if err != nil {
return err
}
ks.dbMu.Lock()
ks.db = newDB
ks.dbMu.Unlock()
return nil
}
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
// and returns the raw key once.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
type createRequest struct {
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"`
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
reqJSON, err := json.Marshal(createRequest{
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
})
if err != nil {
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("create key procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
}
var key UserKey
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse created key: %w", err)
}
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
var success bool
var errorMsg sql.NullString
var keysJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
}
var keys []UserKey
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
return nil, fmt.Errorf("failed to parse user keys: %w", err)
}
}
if keys == nil {
keys = []UserKey{}
}
return keys, nil
}
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
// The delete procedure returns the key_hash so no separate lookup is needed.
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
var success bool
var errorMsg sql.NullString
var keyHash sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
return fmt.Errorf("delete key procedure failed: %w", err)
}
if !success {
return errors.New(nullStringOr(errorMsg, "delete key failed"))
}
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
}
return nil
}
// ValidateKey hashes the raw key and calls the validate procedure.
// Results are cached for CacheTTL to reduce DB load on hot paths.
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
cacheKey := keystoreCacheKey(hash)
if ks.cache != nil {
var cached UserKey
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
if cached.IsActive {
return &cached, nil
}
return nil, errors.New("invalid or expired key")
}
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
}
if err := runQuery(); err != nil {
if isDBClosed(err) {
if reconnErr := ks.reconnectDB(); reconnErr == nil {
err = runQuery()
}
if err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
} else {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
}
var key UserKey
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse validated key: %w", err)
}
if ks.cache != nil {
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
}
return &key, nil
}
func keystoreCacheKey(hash string) string {
return "keystore:validate:" + hash
}
// nullStringOr returns s.String if valid, otherwise the fallback.
func nullStringOr(s sql.NullString, fallback string) string {
if s.Valid && s.String != "" {
return s.String
}
return fallback
}

View File

@@ -0,0 +1,187 @@
-- Keystore schema for per-user auth keys
-- Apply alongside database_schema.sql (requires the users table)
CREATE TABLE IF NOT EXISTS user_keys (
id BIGSERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_type VARCHAR(50) NOT NULL,
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
name VARCHAR(255) NOT NULL DEFAULT '',
scopes TEXT, -- JSON array, e.g. '["read","write"]'
meta JSONB,
expires_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used_at TIMESTAMP,
is_active BOOLEAN DEFAULT true
);
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
-- resolvespec_keystore_get_user_keys
-- Returns all active, non-expired keys for a user.
-- Pass empty p_key_type to return all key types.
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
p_user_id INTEGER,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_keys JSONB;
BEGIN
SELECT COALESCE(
jsonb_agg(
jsonb_build_object(
'id', k.id,
'user_id', k.user_id,
'key_type', k.key_type,
'name', k.name,
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
'meta', COALESCE(k.meta, '{}'::jsonb),
'expires_at', k.expires_at,
'created_at', k.created_at,
'last_used_at', k.last_used_at,
'is_active', k.is_active
)
),
'[]'::jsonb
)
INTO v_keys
FROM user_keys k
WHERE k.user_id = p_user_id
AND k.is_active = true
AND (k.expires_at IS NULL OR k.expires_at > NOW())
AND (p_key_type = '' OR k.key_type = p_key_type);
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_create_key
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
-- Returns the created key record (without key_hash).
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
p_request JSONB
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_id BIGINT;
v_created_at TIMESTAMP;
v_key JSONB;
BEGIN
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
VALUES (
(p_request->>'user_id')::INTEGER,
p_request->>'key_type',
p_request->>'key_hash',
COALESCE(p_request->>'name', ''),
p_request->>'scopes',
p_request->'meta',
CASE WHEN p_request->>'expires_at' IS NOT NULL
THEN (p_request->>'expires_at')::TIMESTAMP
ELSE NULL
END
)
RETURNING id, created_at INTO v_id, v_created_at;
v_key := jsonb_build_object(
'id', v_id,
'user_id', (p_request->>'user_id')::INTEGER,
'key_type', p_request->>'key_type',
'name', COALESCE(p_request->>'name', ''),
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
THEN (p_request->>'scopes')::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
'expires_at', p_request->>'expires_at',
'created_at', v_created_at,
'is_active', true
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_delete_key
-- Soft-deletes a key (is_active = false) after verifying ownership.
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
p_user_id INTEGER,
p_key_id BIGINT
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
LANGUAGE plpgsql AS $$
DECLARE
v_hash TEXT;
BEGIN
UPDATE user_keys
SET is_active = false
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
RETURNING key_hash INTO v_hash;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
RETURN;
END IF;
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
END;
$$;
-- resolvespec_keystore_validate_key
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
-- updates last_used_at, and returns the key record.
-- p_key_type can be empty to accept any key type.
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
p_key_hash TEXT,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_key_rec user_keys%ROWTYPE;
v_key JSONB;
BEGIN
SELECT * INTO v_key_rec
FROM user_keys
WHERE key_hash = p_key_hash
AND is_active = true
AND (expires_at IS NULL OR expires_at > NOW())
AND (p_key_type = '' OR key_type = p_key_type);
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
RETURN;
END IF;
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
v_key := jsonb_build_object(
'id', v_key_rec.id,
'user_id', v_key_rec.user_id,
'key_type', v_key_rec.key_type,
'name', v_key_rec.name,
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
THEN v_key_rec.scopes::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
'expires_at', v_key_rec.expires_at,
'created_at', v_key_rec.created_at,
'last_used_at', NOW(),
'is_active', v_key_rec.is_active
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;

View File

@@ -0,0 +1,61 @@
package security
import "fmt"
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
type KeyStoreSQLNames struct {
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
CreateKey string // default: "resolvespec_keystore_create_key"
DeleteKey string // default: "resolvespec_keystore_delete_key"
ValidateKey string // default: "resolvespec_keystore_validate_key"
}
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
return &KeyStoreSQLNames{
GetUserKeys: "resolvespec_keystore_get_user_keys",
CreateKey: "resolvespec_keystore_create_key",
DeleteKey: "resolvespec_keystore_delete_key",
ValidateKey: "resolvespec_keystore_validate_key",
}
}
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.GetUserKeys != "" {
merged.GetUserKeys = override.GetUserKeys
}
if override.CreateKey != "" {
merged.CreateKey = override.CreateKey
}
if override.DeleteKey != "" {
merged.DeleteKey = override.DeleteKey
}
if override.ValidateKey != "" {
merged.ValidateKey = override.ValidateKey
}
return &merged
}
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
fields := map[string]string{
"GetUserKeys": names.GetUserKeys,
"CreateKey": names.CreateKey,
"DeleteKey": names.DeleteKey,
"ValidateKey": names.ValidateKey,
}
for field, val := range fields {
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
}
}
return nil
}

View File

@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string var errMsg *string
var userID *int var userID *int
err = a.db.QueryRowContext(ctx, ` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id SELECT p_success, p_error, p_user_id
FROM resolvespec_oauth_getorcreateuser($1::jsonb) FROM %s($1::jsonb)
`, userJSON).Scan(&success, &errMsg, &userID) `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get or create user: %w", err) return 0, fmt.Errorf("failed to get or create user: %w", err)
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool var success bool
var errMsg *string var errMsg *string
err = a.db.QueryRowContext(ctx, ` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM resolvespec_oauth_createsession($1::jsonb) FROM %s($1::jsonb)
`, sessionJSON).Scan(&success, &errMsg) `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to create session: %w", err) return fmt.Errorf("failed to create session: %w", err)
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string var errMsg *string
var sessionData []byte var sessionData []byte
err = a.db.QueryRowContext(ctx, ` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getrefreshtoken($1) FROM %s($1)
`, refreshToken).Scan(&success, &errMsg, &sessionData) `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get session by refresh token: %w", err) return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool var updateSuccess bool
var updateErrMsg *string var updateErrMsg *string
err = a.db.QueryRowContext(ctx, ` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb) FROM %s($1::jsonb)
`, updateJSON).Scan(&updateSuccess, &updateErrMsg) `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update session: %w", err) return nil, fmt.Errorf("failed to update session: %w", err)
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string var userErrMsg *string
var userData []byte var userData []byte
err = a.db.QueryRowContext(ctx, ` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getuser($1) FROM %s($1)
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData) `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err) return nil, fmt.Errorf("failed to get user data: %w", err)

View File

@@ -7,16 +7,21 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
) )
// DatabasePasskeyProvider implements PasskeyProvider using database storage // DatabasePasskeyProvider implements PasskeyProvider using database storage
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct { type DatabasePasskeyProvider struct {
db *sql.DB db *sql.DB
rpID string // Relying Party ID (domain) dbMu sync.RWMutex
rpName string // Relying Party display name dbFactory func() (*sql.DB, error)
rpOrigin string // Expected origin for WebAuthn rpID string // Relying Party ID (domain)
timeout int64 // Timeout in milliseconds (default: 60000) rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
} }
// DatabasePasskeyProviderOptions configures the passkey provider // DatabasePasskeyProviderOptions configures the passkey provider
@@ -29,6 +34,11 @@ type DatabasePasskeyProviderOptions struct {
RPOrigin string RPOrigin string
// Timeout is the timeout for operations in milliseconds (default: 60000) // Timeout is the timeout for operations in milliseconds (default: 60000)
Timeout int64 Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
} }
// NewDatabasePasskeyProvider creates a new database-backed passkey provider // NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -37,15 +47,39 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
opts.Timeout = 60000 // 60 seconds default opts.Timeout = 60000 // 60 seconds default
} }
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabasePasskeyProvider{ return &DatabasePasskeyProvider{
db: db, db: db,
rpID: opts.RPID, dbFactory: opts.DBFactory,
rpName: opts.RPName, rpID: opts.RPID,
rpOrigin: opts.RPOrigin, rpName: opts.RPName,
timeout: opts.Timeout, rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
sqlNames: sqlNames,
} }
} }
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabasePasskeyProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// BeginRegistration creates registration options for a new passkey // BeginRegistration creates registration options for a new passkey
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) { func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
// Generate challenge // Generate challenge
@@ -132,8 +166,8 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialID sql.NullInt64 var credentialID sql.NullInt64
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err) return nil, fmt.Errorf("failed to store credential: %w", err)
} }
@@ -173,8 +207,8 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var userID sql.NullInt64 var userID sql.NullInt64
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
} }
@@ -233,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialJSON sql.NullString var credentialJSON sql.NullString
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)` runQuery := func() error {
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err) return 0, fmt.Errorf("failed to get credential: %w", err)
} }
@@ -264,8 +306,8 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var updateError sql.NullString var updateError sql.NullString
var cloneWarning sql.NullBool var cloneWarning sql.NullBool
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)` updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err) return 0, fmt.Errorf("failed to update counter: %w", err)
} }
@@ -283,8 +325,8 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
} }
@@ -362,8 +404,8 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete credential: %w", err) return fmt.Errorf("failed to delete credential: %w", err)
} }
@@ -388,8 +430,8 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to update credential name: %w", err) return fmt.Errorf("failed to update credential name: %w", err)
} }

View File

@@ -58,15 +58,17 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// DatabaseAuthenticator provides session-based authentication with database storage // DatabaseAuthenticator provides session-based authentication with database storage
// All database operations go through stored procedures for security and consistency // All database operations go through stored procedures for security and consistency
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session, // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// resolvespec_session_update, resolvespec_refresh_token
// See database_schema.sql for procedure definitions // See database_schema.sql for procedure definitions
// Also supports multiple OAuth2 providers configured with WithOAuth2() // Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey() // Also supports passkey authentication configured with WithPasskey()
type DatabaseAuthenticator struct { type DatabaseAuthenticator struct {
db *sql.DB db *sql.DB
cache *cache.Cache dbMu sync.RWMutex
cacheTTL time.Duration dbFactory func() (*sql.DB, error)
cache *cache.Cache
cacheTTL time.Duration
sqlNames *SQLNames
// OAuth2 providers registry (multiple providers supported) // OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider oauth2Providers map[string]*OAuth2Provider
@@ -85,6 +87,12 @@ type DatabaseAuthenticatorOptions struct {
Cache *cache.Cache Cache *cache.Cache
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication // PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
PasskeyProvider PasskeyProvider PasskeyProvider PasskeyProvider
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
} }
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -103,14 +111,38 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
cacheInstance = cache.GetDefaultCache() cacheInstance = cache.GetDefaultCache()
} }
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{ return &DatabaseAuthenticator{
db: db, db: db,
dbFactory: opts.DBFactory,
cache: cacheInstance, cache: cacheInstance,
cacheTTL: opts.CacheTTL, cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider, passkeyProvider: opts.PasskeyProvider,
} }
} }
func (a *DatabaseAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *DatabaseAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON // Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req) reqJSON, err := json.Marshal(req)
@@ -118,13 +150,20 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
return nil, fmt.Errorf("failed to marshal login request: %w", err) return nil, fmt.Errorf("failed to marshal login request: %w", err)
} }
// Call resolvespec_login stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)` runLoginQuery := func() error {
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
} }
@@ -153,13 +192,12 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
return nil, fmt.Errorf("failed to marshal register request: %w", err) return nil, fmt.Errorf("failed to marshal register request: %w", err)
} }
// Call resolvespec_register stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("register query failed: %w", err) return nil, fmt.Errorf("register query failed: %w", err)
} }
@@ -187,13 +225,12 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
return fmt.Errorf("failed to marshal logout request: %w", err) return fmt.Errorf("failed to marshal logout request: %w", err)
} }
// Call resolvespec_logout stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
} }
@@ -266,8 +303,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON sql.NullString var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("session query failed: %w", err) return nil, fmt.Errorf("session query failed: %w", err)
} }
@@ -338,25 +375,23 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
return return
} }
// Call resolvespec_session_update stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var updatedUserJSON sql.NullString var updatedUserJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) _ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
} }
// RefreshToken implements Refreshable interface // RefreshToken implements Refreshable interface
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Call api_refresh_token stored procedure
// First, we need to get the current user context for the refresh token // First, we need to get the current user context for the refresh token
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON sql.NullString var userJSON sql.NullString
// Get current session to pass to refresh // Get current session to pass to refresh
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err) return nil, fmt.Errorf("refresh token query failed: %w", err)
} }
@@ -368,13 +403,12 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
return nil, fmt.Errorf("invalid refresh token") return nil, fmt.Errorf("invalid refresh token")
} }
// Call resolvespec_refresh_token to generate new token
var newSuccess bool var newSuccess bool
var newErrorMsg sql.NullString var newErrorMsg sql.NullString
var newUserJSON sql.NullString var newUserJSON sql.NullString
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)` refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err) return nil, fmt.Errorf("refresh token generation failed: %w", err)
} }
@@ -401,28 +435,65 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
// JWTAuthenticator provides JWT token-based authentication // JWTAuthenticator provides JWT token-based authentication
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported // NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
type JWTAuthenticator struct { type JWTAuthenticator struct {
secretKey []byte secretKey []byte
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
} }
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator { func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
return &JWTAuthenticator{ return &JWTAuthenticator{
secretKey: []byte(secretKey), secretKey: []byte(secretKey),
db: db, db: db,
sqlNames: resolveSQLNames(names...),
} }
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
a.dbFactory = factory
return a
}
func (a *JWTAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *JWTAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Call resolvespec_jwt_login stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON []byte var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)` runLoginQuery := func() error {
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
}
err := runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
} }
@@ -471,12 +542,11 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
} }
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Call resolvespec_jwt_logout stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
} }
@@ -511,25 +581,60 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// DatabaseColumnSecurityProvider loads column security from database // DatabaseColumnSecurityProvider loads column security from database
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedure: resolvespec_column_security // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct { type DatabaseColumnSecurityProvider struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
} }
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider { func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db} return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
} }
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity var rules []ColumnSecurity
// Call resolvespec_column_security stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var rulesJSON []byte var rulesJSON []byte
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)` runQuery := func() error {
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err) return nil, fmt.Errorf("failed to load column security: %w", err)
} }
@@ -576,23 +681,57 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// DatabaseRowSecurityProvider loads row security from database // DatabaseRowSecurityProvider loads row security from database
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedure: resolvespec_row_security // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct { type DatabaseRowSecurityProvider struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
} }
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider { func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db} return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
} }
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string var template string
var hasBlock bool var hasBlock bool
// Call resolvespec_row_security stored procedure runQuery := func() error {
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)` query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) }
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil { if err != nil {
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err) return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
} }
@@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
// Helper functions // Helper functions
// ================ // ================
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
func parseRoles(rolesStr string) []string { func parseRoles(rolesStr string) []string {
if rolesStr == "" { if rolesStr == "" {
return []string{} return []string{}
@@ -758,56 +902,55 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
return nil, fmt.Errorf("passkey authentication failed: %w", err) return nil, fmt.Errorf("passkey authentication failed: %w", err)
} }
// Get user data from database // Build request JSON for passkey login stored procedure
var username, email, roles string reqData := map[string]any{
var userLevel int "user_id": userID,
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
} }
// Generate session token
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Extract IP and user agent from claims
ipAddress := ""
userAgent := ""
if req.Claims != nil { if req.Claims != nil {
if ip, ok := req.Claims["ip_address"].(string); ok { if ip, ok := req.Claims["ip_address"].(string); ok {
ipAddress = ip reqData["ip_address"] = ip
} }
if ua, ok := req.Claims["user_agent"].(string); ok { if ua, ok := req.Claims["user_agent"].(string); ok {
userAgent = ua reqData["user_agent"] = ua
} }
} }
// Create session reqJSON, err := json.Marshal(reqData)
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES ($1, $2, $3, $4, $5, now())`
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err) return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
} }
// Update last login var success bool
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1` var errorMsg sql.NullString
_, _ = a.db.ExecContext(ctx, updateQuery, userID) var dataJSON sql.NullString
// Return login response runPasskeyQuery := func() error {
return &LoginResponse{ query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
Token: sessionToken, return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
User: &UserContext{ }
UserID: userID, err = runPasskeyQuery()
UserName: username, if isDBClosed(err) {
Email: email, if reconnErr := a.reconnectDB(); reconnErr == nil {
UserLevel: userLevel, err = runPasskeyQuery()
SessionID: sessionToken, }
Roles: parseRoles(roles), }
}, if err != nil {
ExpiresIn: int64(24 * time.Hour.Seconds()), return nil, fmt.Errorf("passkey login query failed: %w", err)
}, nil }
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("passkey login failed")
}
var response LoginResponse
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
}
return &response, nil
} }
// GetPasskeyCredentials returns all passkey credentials for a user // GetPasskeyCredentials returns all passkey credentials for a user

222
pkg/security/sql_names.go Normal file
View File

@@ -0,0 +1,222 @@
package security
import (
"fmt"
"reflect"
"regexp"
)
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// SQLNames defines all configurable SQL stored procedure and table names
// used by the security package. Override individual fields to remap
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
// and MergeSQLNames() to apply partial overrides.
type SQLNames struct {
// Auth procedures (DatabaseAuthenticator)
Login string // default: "resolvespec_login"
Register string // default: "resolvespec_register"
Logout string // default: "resolvespec_logout"
Session string // default: "resolvespec_session"
SessionUpdate string // default: "resolvespec_session_update"
RefreshToken string // default: "resolvespec_refresh_token"
// JWT procedures (JWTAuthenticator)
JWTLogin string // default: "resolvespec_jwt_login"
JWTLogout string // default: "resolvespec_jwt_logout"
// Security policy procedures
ColumnSecurity string // default: "resolvespec_column_security"
RowSecurity string // default: "resolvespec_row_security"
// TOTP procedures (DatabaseTwoFactorProvider)
TOTPEnable string // default: "resolvespec_totp_enable"
TOTPDisable string // default: "resolvespec_totp_disable"
TOTPGetStatus string // default: "resolvespec_totp_get_status"
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
// Passkey procedures (DatabasePasskeyProvider)
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
PasskeyLogin string // default: "resolvespec_passkey_login"
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser"
}
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
func DefaultSQLNames() *SQLNames {
return &SQLNames{
Login: "resolvespec_login",
Register: "resolvespec_register",
Logout: "resolvespec_logout",
Session: "resolvespec_session",
SessionUpdate: "resolvespec_session_update",
RefreshToken: "resolvespec_refresh_token",
JWTLogin: "resolvespec_jwt_login",
JWTLogout: "resolvespec_jwt_logout",
ColumnSecurity: "resolvespec_column_security",
RowSecurity: "resolvespec_row_security",
TOTPEnable: "resolvespec_totp_enable",
TOTPDisable: "resolvespec_totp_disable",
TOTPGetStatus: "resolvespec_totp_get_status",
TOTPGetSecret: "resolvespec_totp_get_secret",
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
PasskeyGetCredential: "resolvespec_passkey_get_credential",
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
PasskeyUpdateName: "resolvespec_passkey_update_name",
PasskeyLogin: "resolvespec_passkey_login",
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
OAuthCreateSession: "resolvespec_oauth_createsession",
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser",
}
}
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.Login != "" {
merged.Login = override.Login
}
if override.Register != "" {
merged.Register = override.Register
}
if override.Logout != "" {
merged.Logout = override.Logout
}
if override.Session != "" {
merged.Session = override.Session
}
if override.SessionUpdate != "" {
merged.SessionUpdate = override.SessionUpdate
}
if override.RefreshToken != "" {
merged.RefreshToken = override.RefreshToken
}
if override.JWTLogin != "" {
merged.JWTLogin = override.JWTLogin
}
if override.JWTLogout != "" {
merged.JWTLogout = override.JWTLogout
}
if override.ColumnSecurity != "" {
merged.ColumnSecurity = override.ColumnSecurity
}
if override.RowSecurity != "" {
merged.RowSecurity = override.RowSecurity
}
if override.TOTPEnable != "" {
merged.TOTPEnable = override.TOTPEnable
}
if override.TOTPDisable != "" {
merged.TOTPDisable = override.TOTPDisable
}
if override.TOTPGetStatus != "" {
merged.TOTPGetStatus = override.TOTPGetStatus
}
if override.TOTPGetSecret != "" {
merged.TOTPGetSecret = override.TOTPGetSecret
}
if override.TOTPRegenerateBackup != "" {
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
}
if override.TOTPValidateBackupCode != "" {
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
}
if override.PasskeyStoreCredential != "" {
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
}
if override.PasskeyGetCredsByUsername != "" {
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
}
if override.PasskeyGetCredential != "" {
merged.PasskeyGetCredential = override.PasskeyGetCredential
}
if override.PasskeyUpdateCounter != "" {
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
}
if override.PasskeyGetUserCredentials != "" {
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
}
if override.PasskeyDeleteCredential != "" {
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
}
if override.PasskeyUpdateName != "" {
merged.PasskeyUpdateName = override.PasskeyUpdateName
}
if override.PasskeyLogin != "" {
merged.PasskeyLogin = override.PasskeyLogin
}
if override.OAuthGetOrCreateUser != "" {
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
}
if override.OAuthCreateSession != "" {
merged.OAuthCreateSession = override.OAuthCreateSession
}
if override.OAuthGetRefreshToken != "" {
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
}
if override.OAuthUpdateRefreshToken != "" {
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
}
if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser
}
return &merged
}
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
// Returns an error if any field contains invalid characters.
func ValidateSQLNames(names *SQLNames) error {
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
val := field.String()
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
}
}
return nil
}
// resolveSQLNames merges an optional override with defaults.
// Used by constructors that accept variadic *SQLNames.
func resolveSQLNames(override ...*SQLNames) *SQLNames {
if len(override) > 0 && override[0] != nil {
return MergeSQLNames(DefaultSQLNames(), override[0])
}
return DefaultSQLNames()
}

View File

@@ -0,0 +1,145 @@
package security
import (
"reflect"
"testing"
)
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
names := DefaultSQLNames()
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
if field.String() == "" {
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
}
}
}
func TestMergeSQLNames_PartialOverride(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{
Login: "custom_login",
TOTPEnable: "custom_totp_enable",
PasskeyLogin: "custom_passkey_login",
}
merged := MergeSQLNames(base, override)
if merged.Login != "custom_login" {
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
}
if merged.TOTPEnable != "custom_totp_enable" {
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
}
if merged.PasskeyLogin != "custom_passkey_login" {
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
}
// Non-overridden fields should retain defaults
if merged.Logout != "resolvespec_logout" {
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
}
if merged.Session != "resolvespec_session" {
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
}
}
func TestMergeSQLNames_NilOverride(t *testing.T) {
base := DefaultSQLNames()
merged := MergeSQLNames(base, nil)
// Should be a copy, not the same pointer
if merged == base {
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
}
// All values should match
v1 := reflect.ValueOf(base).Elem()
v2 := reflect.ValueOf(merged).Elem()
typ := v1.Type()
for i := 0; i < v1.NumField(); i++ {
f1 := v1.Field(i)
f2 := v2.Field(i)
if f1.Kind() != reflect.String {
continue
}
if f1.String() != f2.String() {
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
}
}
}
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
base := DefaultSQLNames()
originalLogin := base.Login
override := &SQLNames{Login: "custom_login"}
_ = MergeSQLNames(base, override)
if base.Login != originalLogin {
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
}
}
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{}
v := reflect.ValueOf(override).Elem()
for i := 0; i < v.NumField(); i++ {
if v.Field(i).Kind() == reflect.String {
v.Field(i).SetString("custom_sentinel")
}
}
merged := MergeSQLNames(base, override)
mv := reflect.ValueOf(merged).Elem()
typ := mv.Type()
for i := 0; i < mv.NumField(); i++ {
if mv.Field(i).Kind() != reflect.String {
continue
}
if mv.Field(i).String() != "custom_sentinel" {
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
}
}
}
func TestValidateSQLNames_Valid(t *testing.T) {
names := DefaultSQLNames()
if err := ValidateSQLNames(names); err != nil {
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
}
}
func TestValidateSQLNames_Invalid(t *testing.T) {
names := DefaultSQLNames()
names.Login = "resolvespec_login; DROP TABLE users; --"
err := ValidateSQLNames(names)
if err == nil {
t.Error("ValidateSQLNames should reject names with invalid characters")
}
}
func TestResolveSQLNames_NoOverride(t *testing.T) {
names := resolveSQLNames()
if names.Login != "resolvespec_login" {
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
}
}
func TestResolveSQLNames_WithOverride(t *testing.T) {
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
if names.Login != "custom_login" {
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
}
if names.Logout != "resolvespec_logout" {
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
}
}

View File

@@ -9,23 +9,23 @@ import (
) )
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures // DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable, // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
// See totp_database_schema.sql for procedure definitions // See totp_database_schema.sql for procedure definitions
type DatabaseTwoFactorProvider struct { type DatabaseTwoFactorProvider struct {
db *sql.DB db *sql.DB
totpGen *TOTPGenerator totpGen *TOTPGenerator
sqlNames *SQLNames
} }
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider // NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider { func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
if config == nil { if config == nil {
config = DefaultTwoFactorConfig() config = DefaultTwoFactorConfig()
} }
return &DatabaseTwoFactorProvider{ return &DatabaseTwoFactorProvider{
db: db, db: db,
totpGen: NewTOTPGenerator(config), totpGen: NewTOTPGenerator(config),
sqlNames: resolveSQLNames(names...),
} }
} }
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg) err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("enable 2FA query failed: %w", err) return fmt.Errorf("enable 2FA query failed: %w", err)
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("disable 2FA query failed: %w", err) return fmt.Errorf("disable 2FA query failed: %w", err)
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
var errorMsg sql.NullString var errorMsg sql.NullString
var enabled bool var enabled bool
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
if err != nil { if err != nil {
return false, fmt.Errorf("get 2FA status query failed: %w", err) return false, fmt.Errorf("get 2FA status query failed: %w", err)
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var errorMsg sql.NullString var errorMsg sql.NullString
var secret sql.NullString var secret sql.NullString
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
if err != nil { if err != nil {
return "", fmt.Errorf("get 2FA secret query failed: %w", err) return "", fmt.Errorf("get 2FA secret query failed: %w", err)
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg) err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err) return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
var errorMsg sql.NullString var errorMsg sql.NullString
var valid bool var valid bool
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid) err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
if err != nil { if err != nil {
return false, fmt.Errorf("validate backup code query failed: %w", err) return false, fmt.Errorf("validate backup code query failed: %w", err)