mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
Compare commits
14 Commits
v1.0.64
...
feature-ke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79a3912f93 | ||
|
|
a9bf08f58b | ||
|
|
405a04a192 | ||
|
|
c1b16d363a | ||
|
|
568df8c6d6 | ||
|
|
aa362c77da | ||
|
|
1641eaf278 | ||
|
|
200a03c225 | ||
|
|
7ef9cf39d3 | ||
|
|
7f6410f665 | ||
|
|
835bbb0727 | ||
|
|
047a1cc187 | ||
|
|
7a498edab7 | ||
|
|
f10bb0827e |
71
README.md
71
README.md
@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
|
||||
|
||||
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
|
||||
* [Quick Start](#quick-start)
|
||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
|
||||
* [Architecture](#architecture)
|
||||
* [API Structure](#api-structure)
|
||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
|
||||
* **🆕 Backward Compatible**: Existing code works without changes
|
||||
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
||||
|
||||
### ResolveMCP (v3.2+)
|
||||
|
||||
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
|
||||
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
|
||||
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
|
||||
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
|
||||
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
|
||||
|
||||
### RestHeadSpec (v2.1+)
|
||||
|
||||
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
||||
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
### ResolveMCP (MCP Server)
|
||||
|
||||
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
|
||||
// Create handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db)
|
||||
|
||||
// Register models — must be done BEFORE Build()
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "posts", &Post{})
|
||||
|
||||
// Finalize: registers MCP tools and resources
|
||||
handler.Build()
|
||||
|
||||
// Mount SSE transport on your existing router
|
||||
router := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
|
||||
|
||||
// MCP clients connect to:
|
||||
// SSE stream: GET http://localhost:8080/mcp/sse
|
||||
// Messages: POST http://localhost:8080/mcp/message
|
||||
//
|
||||
// Auto-registered tools per model:
|
||||
// read_public_users — filter, sort, paginate, preload
|
||||
// create_public_users — insert a new record
|
||||
// update_public_users — update a record by ID
|
||||
// delete_public_users — delete a record by ID
|
||||
```
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two Complementary APIs
|
||||
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
#### ResolveMCP - MCP Server
|
||||
|
||||
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
|
||||
|
||||
**Key Features**:
|
||||
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
|
||||
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
|
||||
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
|
||||
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
|
||||
- Same Before/After lifecycle hooks as ResolveSpec
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
|
||||
|
||||
#### FuncSpec - Function-Based SQL API
|
||||
|
||||
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
||||
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
|
||||
## What's New
|
||||
|
||||
### v3.1 (Latest - February 2026)
|
||||
### v3.2 (Latest - March 2026)
|
||||
|
||||
**ResolveMCP - Model Context Protocol Server (🆕)**:
|
||||
|
||||
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
|
||||
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
|
||||
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
|
||||
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
|
||||
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
|
||||
|
||||
### v3.1 (February 2026)
|
||||
|
||||
**SQLite Schema Translation (🆕)**:
|
||||
|
||||
|
||||
5
go.mod
5
go.mod
@@ -40,6 +40,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
@@ -78,6 +79,7 @@ require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -86,6 +88,7 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mark3labs/mcp-go v0.46.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
@@ -131,6 +134,7 @@ require (
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.2.0 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
@@ -143,7 +147,6 @@ require (
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
db *bun.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*bun.DB, error)
|
||||
driverName string
|
||||
}
|
||||
|
||||
@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
|
||||
b.dbFactory = factory
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunAdapter) getDB() *bun.DB {
|
||||
b.dbMu.RLock()
|
||||
defer b.dbMu.RUnlock()
|
||||
return b.db
|
||||
}
|
||||
|
||||
func (b *BunAdapter) reconnectDB() error {
|
||||
if b.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := b.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.dbMu.Lock()
|
||||
b.db = newDB
|
||||
b.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (b *BunAdapter) EnableQueryDebug() {
|
||||
b.db.AddQueryHook(&QueryDebugHook{})
|
||||
b.getDB().AddQueryHook(&QueryDebugHook{})
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
query: b.getDB().NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||
return &BunInsertQuery{query: b.db.NewInsert()}
|
||||
return &BunInsertQuery{query: b.getDB().NewInsert()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &BunUpdateQuery{query: b.db.NewUpdate()}
|
||||
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||
return &BunDeleteQuery{query: b.getDB().NewDelete()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.db.ExecContext(ctx, query, args...)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
||||
return fn(adapter)
|
||||
@@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
}
|
||||
|
||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.db
|
||||
return b.getDB()
|
||||
}
|
||||
|
||||
func (b *BunAdapter) DriverName() string {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -17,6 +18,8 @@ import (
|
||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||
type PgSQLAdapter struct {
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
driverName string
|
||||
}
|
||||
|
||||
@@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
||||
return &PgSQLAdapter{db: db, driverName: name}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging for development
|
||||
func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
logger.Info("PgSQL query debug mode - logging enabled via logger")
|
||||
@@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
|
||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
@@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
|
||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
@@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
|
||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
|
||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
||||
}
|
||||
}()
|
||||
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
||||
result, err := p.db.ExecContext(ctx, query, args...)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Exec failed: %v", err)
|
||||
return nil, err
|
||||
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
}
|
||||
}()
|
||||
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
||||
rows, err := p.db.QueryContext(ctx, query, args...)
|
||||
var rows *sql.Rows
|
||||
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Query failed: %v", err)
|
||||
return err
|
||||
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
}
|
||||
}()
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -168,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
}
|
||||
|
||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||
// Keys are stored lowercase for case-insensitive matching
|
||||
allowedPrefixes := make(map[string]bool)
|
||||
if tableName != "" {
|
||||
allowedPrefixes[tableName] = true
|
||||
allowedPrefixes[strings.ToLower(tableName)] = true
|
||||
}
|
||||
|
||||
// Add preload relation names as allowed prefixes
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
for pi := range options[0].Preload {
|
||||
if options[0].Preload[pi].Relation != "" {
|
||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
@@ -185,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
// Add join aliases as allowed prefixes
|
||||
for _, alias := range options[0].JoinAliases {
|
||||
if alias != "" {
|
||||
allowedPrefixes[alias] = true
|
||||
allowedPrefixes[strings.ToLower(alias)] = true
|
||||
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||
}
|
||||
}
|
||||
@@ -217,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||
|
||||
if currentPrefix != "" && columnName != "" {
|
||||
// Check if the prefix is allowed (main table or preload relation)
|
||||
if !allowedPrefixes[currentPrefix] {
|
||||
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
|
||||
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
|
||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace the incorrect prefix with the correct main table name
|
||||
|
||||
@@ -26,6 +26,7 @@ type Connection interface {
|
||||
Bun() (*bun.DB, error)
|
||||
GORM() (*gorm.DB, error)
|
||||
Native() (*sql.DB, error)
|
||||
DB() (*sql.DB, error)
|
||||
|
||||
// Common Database interface (for SQL databases)
|
||||
Database() (common.Database, error)
|
||||
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
|
||||
// DB returns the underlying *sql.DB connection
|
||||
func (c *sqlConnection) DB() (*sql.DB, error) {
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
// Bun returns a Bun ORM instance wrapping the native connection
|
||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||
if c == nil {
|
||||
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// DB returns an error for MongoDB connections
|
||||
func (c *mongoConnection) DB() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Database returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Database() (common.Database, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
|
||||
@@ -4,11 +4,17 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||
@@ -14,8 +15,10 @@ import (
|
||||
|
||||
// SQLiteProvider implements Provider for SQLite databases
|
||||
type SQLiteProvider struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
config ConnectionConfig
|
||||
}
|
||||
|
||||
// NewSQLiteProvider creates a new SQLite provider
|
||||
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
|
||||
// Execute a simple query to verify the database is accessible
|
||||
var result int
|
||||
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
||||
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
|
||||
err := run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns the native *sql.DB connection
|
||||
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||
if p.db == nil {
|
||||
|
||||
513
pkg/resolvemcp/README.md
Normal file
513
pkg/resolvemcp/README.md
Normal file
@@ -0,0 +1,513 @@
|
||||
# resolvemcp
|
||||
|
||||
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// 1. Create a handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
|
||||
BaseURL: "http://localhost:8080",
|
||||
})
|
||||
|
||||
// 2. Register models
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "orders", &Order{})
|
||||
|
||||
// 3. Mount routes
|
||||
r := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Config
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
// If empty, it is detected from each incoming request using the Host header and
|
||||
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
|
||||
// Required.
|
||||
BasePath string
|
||||
}
|
||||
```
|
||||
|
||||
## Handler Creation
|
||||
|
||||
| Function | Description |
|
||||
|---|---|
|
||||
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
|
||||
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
|
||||
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
|
||||
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
|
||||
|
||||
---
|
||||
|
||||
## Registering Models
|
||||
|
||||
```go
|
||||
handler.RegisterModel(schema, entity string, model interface{}) error
|
||||
```
|
||||
|
||||
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
|
||||
- `entity` — table/entity name (e.g. `"users"`).
|
||||
- `model` — a pointer to a struct (e.g. `&User{}`).
|
||||
|
||||
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
|
||||
|
||||
---
|
||||
|
||||
## HTTP / SSE Transport
|
||||
|
||||
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
|
||||
|
||||
`Config.BasePath` is required and used for all route registration.
|
||||
`Config.BaseURL` is optional — when empty it is detected from each request.
|
||||
|
||||
### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
Registers:
|
||||
|
||||
| Route | Method | Description |
|
||||
|---|---|---|
|
||||
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
|
||||
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
|
||||
| `{BasePath}/*` | any | Full SSE server (convenience prefix) |
|
||||
|
||||
### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterRoutes(router, handler)
|
||||
```
|
||||
|
||||
Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`.
|
||||
|
||||
### Gin (or any `http.Handler`-compatible framework)
|
||||
|
||||
Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
|
||||
|
||||
```go
|
||||
sse := handler.SSEServer()
|
||||
|
||||
// Gin
|
||||
engine.Any("/mcp/*path", gin.WrapH(sse))
|
||||
|
||||
// net/http
|
||||
http.Handle("/mcp/", sse)
|
||||
|
||||
// Echo
|
||||
e.Any("/mcp/*", echo.WrapHandler(sse))
|
||||
```
|
||||
|
||||
### Authentication
|
||||
|
||||
Add middleware before the MCP routes. The handler itself has no auth layer.
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
|
||||
|
||||
### Wiring security hooks
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
securityList := security.NewSecurityList(mySecurityProvider)
|
||||
resolvemcp.RegisterSecurityHooks(handler, securityList)
|
||||
```
|
||||
|
||||
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
|
||||
|
||||
| Hook | Effect |
|
||||
|---|---|
|
||||
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
|
||||
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
|
||||
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
|
||||
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
|
||||
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
|
||||
|
||||
### Per-entity operation rules
|
||||
|
||||
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
// Read-only entity
|
||||
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: false,
|
||||
CanUpdate: false,
|
||||
CanDelete: false,
|
||||
})
|
||||
|
||||
// Public read, authenticated write
|
||||
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
|
||||
CanPublicRead: true,
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
To update rules for an already-registered model:
|
||||
|
||||
```go
|
||||
handler.SetModelRules("public", "users", modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
|
||||
|
||||
### ModelRules fields
|
||||
|
||||
| Field | Default | Description |
|
||||
|---|---|---|
|
||||
| `CanPublicRead` | `false` | Allow unauthenticated reads |
|
||||
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
|
||||
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
|
||||
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
|
||||
| `CanRead` | `true` | Allow authenticated reads |
|
||||
| `CanCreate` | `true` | Allow authenticated creates |
|
||||
| `CanUpdate` | `true` | Allow authenticated updates |
|
||||
| `CanDelete` | `true` | Allow authenticated deletes |
|
||||
| `SecurityDisabled` | `false` | Skip all security checks for this model |
|
||||
|
||||
---
|
||||
|
||||
## MCP Tools
|
||||
|
||||
### Tool Naming
|
||||
|
||||
```
|
||||
{operation}_{schema}_{entity} // e.g. read_public_users
|
||||
{operation}_{entity} // e.g. read_users (when schema is empty)
|
||||
```
|
||||
|
||||
Operations: `read`, `create`, `update`, `delete`.
|
||||
|
||||
### Read Tool — `read_{schema}_{entity}`
|
||||
|
||||
Fetch one or many records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key value. Omit to return multiple records. |
|
||||
| `limit` | number | Max records per page (recommended: 10–100). |
|
||||
| `offset` | number | Records to skip (offset-based pagination). |
|
||||
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
|
||||
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
|
||||
| `columns` | array | Column names to include. Omit for all columns. |
|
||||
| `omit_columns` | array | Column names to exclude. |
|
||||
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
|
||||
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
|
||||
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 100,
|
||||
"count": 10,
|
||||
"limit": 10,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Create Tool — `create_{schema}_{entity}`
|
||||
|
||||
Insert one or more records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `data` | object \| array | Single object or array of objects to insert. |
|
||||
|
||||
Array input runs inside a single transaction — all succeed or all fail.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ... } }
|
||||
```
|
||||
|
||||
### Update Tool — `update_{schema}_{entity}`
|
||||
|
||||
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key of the record. Can also be included inside `data`. |
|
||||
| `data` | object (required) | Fields to update. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...merged record... } }
|
||||
```
|
||||
|
||||
### Delete Tool — `delete_{schema}_{entity}`
|
||||
|
||||
Delete a record by primary key. **Irreversible.**
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string (required) | Primary key of the record to delete. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...deleted record... } }
|
||||
```
|
||||
|
||||
### Annotation Tool — `resolvespec_annotate`
|
||||
|
||||
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
|
||||
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
|
||||
|
||||
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "set" }
|
||||
```
|
||||
|
||||
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users" }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resource — `{schema}.{entity}`
|
||||
|
||||
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||
|
||||
---
|
||||
|
||||
## Filtering
|
||||
|
||||
Pass an array of filter objects to the `filters` argument:
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "status", "operator": "=", "value": "active" },
|
||||
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
|
||||
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
|
||||
]
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
| Operator | Aliases | Description |
|
||||
|---|---|---|
|
||||
| `=` | `eq` | Equal |
|
||||
| `!=` | `neq`, `<>` | Not equal |
|
||||
| `>` | `gt` | Greater than |
|
||||
| `>=` | `gte` | Greater than or equal |
|
||||
| `<` | `lt` | Less than |
|
||||
| `<=` | `lte` | Less than or equal |
|
||||
| `like` | | SQL LIKE (case-sensitive) |
|
||||
| `ilike` | | SQL ILIKE (case-insensitive) |
|
||||
| `in` | | Value in list |
|
||||
| `is_null` | | Column IS NULL |
|
||||
| `is_not_null` | | Column IS NOT NULL |
|
||||
|
||||
### Logic Operators
|
||||
|
||||
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
|
||||
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
|
||||
|
||||
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
|
||||
|
||||
---
|
||||
|
||||
## Sorting
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "created_at", "direction": "desc" },
|
||||
{ "column": "name", "direction": "asc" }
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pagination
|
||||
|
||||
### Offset-Based
|
||||
|
||||
```json
|
||||
{ "limit": 20, "offset": 40 }
|
||||
```
|
||||
|
||||
### Cursor-Based
|
||||
|
||||
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
|
||||
|
||||
```json
|
||||
// Next page: pass the PK of the last record on the current page
|
||||
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
|
||||
// Previous page: pass the PK of the first record on the current page
|
||||
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Preloading Relations
|
||||
|
||||
```json
|
||||
[
|
||||
{ "relation": "Profile" },
|
||||
{ "relation": "Orders" }
|
||||
]
|
||||
```
|
||||
|
||||
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
|
||||
|
||||
---
|
||||
|
||||
## Hook System
|
||||
|
||||
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
|
||||
|
||||
### Hook Types
|
||||
|
||||
| Constant | Fires |
|
||||
|---|---|
|
||||
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
|
||||
| `BeforeRead` / `AfterRead` | Around read queries |
|
||||
| `BeforeCreate` / `AfterCreate` | Around insert |
|
||||
| `BeforeUpdate` / `AfterUpdate` | Around update |
|
||||
| `BeforeDelete` / `AfterDelete` | Around delete |
|
||||
|
||||
### Registering Hooks
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
|
||||
// Inject a timestamp before insert
|
||||
if data, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
data["created_at"] = time.Now()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register the same hook for multiple events
|
||||
handler.Hooks().RegisterMultiple(
|
||||
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
|
||||
auditHook,
|
||||
)
|
||||
```
|
||||
|
||||
### HookContext Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `Context` | `context.Context` | Request context |
|
||||
| `Handler` | `*Handler` | The resolvemcp handler |
|
||||
| `Schema` | `string` | Database schema name |
|
||||
| `Entity` | `string` | Entity/table name |
|
||||
| `Model` | `interface{}` | Registered model instance |
|
||||
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
|
||||
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
|
||||
| `ID` | `string` | Primary key from request (read/update/delete) |
|
||||
| `Data` | `interface{}` | Input data (create/update — modifiable) |
|
||||
| `Result` | `interface{}` | Output data (set by After hooks) |
|
||||
| `Error` | `error` | Operation error, if any |
|
||||
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
|
||||
| `Tx` | `common.Database` | Database/transaction handle |
|
||||
| `Abort` | `bool` | Set to `true` to abort the operation |
|
||||
| `AbortMessage` | `string` | Error message returned when aborting |
|
||||
| `AbortCode` | `int` | Optional status code for the abort |
|
||||
|
||||
### Aborting an Operation
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
|
||||
ctx.Abort = true
|
||||
ctx.AbortMessage = "deletion is disabled"
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Managing Hooks
|
||||
|
||||
```go
|
||||
registry := handler.Hooks()
|
||||
registry.HasHooks(resolvemcp.BeforeCreate) // bool
|
||||
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
|
||||
registry.ClearAll() // remove all hooks
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context Helpers
|
||||
|
||||
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
|
||||
|
||||
```go
|
||||
schema := resolvemcp.GetSchema(ctx)
|
||||
entity := resolvemcp.GetEntity(ctx)
|
||||
tableName := resolvemcp.GetTableName(ctx)
|
||||
model := resolvemcp.GetModel(ctx)
|
||||
modelPtr := resolvemcp.GetModelPtr(ctx)
|
||||
```
|
||||
|
||||
You can also set values manually (e.g. in middleware):
|
||||
|
||||
```go
|
||||
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding Custom MCP Tools
|
||||
|
||||
Access the underlying `*server.MCPServer` to register additional tools:
|
||||
|
||||
```go
|
||||
mcpServer := handler.MCPServer()
|
||||
mcpServer.AddTool(myTool, myHandler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Table Name Resolution
|
||||
|
||||
The handler resolves table names in priority order:
|
||||
|
||||
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
|
||||
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
|
||||
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)
|
||||
107
pkg/resolvemcp/annotation.go
Normal file
107
pkg/resolvemcp/annotation.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const annotationToolName = "resolvespec_annotate"
|
||||
|
||||
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
|
||||
// The tool lets models/entities store and retrieve freeform annotation records
|
||||
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
|
||||
func registerAnnotationTool(h *Handler) {
|
||||
tool := mcp.NewTool(annotationToolName,
|
||||
mcp.WithDescription(
|
||||
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
|
||||
"To set annotations: provide both 'tool_name' and 'annotations'. "+
|
||||
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
|
||||
"To get annotations: provide only 'tool_name'. "+
|
||||
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
|
||||
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
|
||||
"a model/entity name (e.g. 'public.users'), or any other key.",
|
||||
),
|
||||
mcp.WithString("tool_name",
|
||||
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
|
||||
mcp.Required(),
|
||||
),
|
||||
mcp.WithObject("annotations",
|
||||
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
|
||||
toolName, ok := args["tool_name"].(string)
|
||||
if !ok || toolName == "" {
|
||||
return mcp.NewToolResultError("missing required argument: tool_name"), nil
|
||||
}
|
||||
|
||||
annotations, hasAnnotations := args["annotations"]
|
||||
|
||||
if hasAnnotations && annotations != nil {
|
||||
return executeSetAnnotation(ctx, h, toolName, annotations)
|
||||
}
|
||||
return executeGetAnnotation(ctx, h, toolName)
|
||||
})
|
||||
}
|
||||
|
||||
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
|
||||
jsonBytes, err := json.Marshal(annotations)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
|
||||
}
|
||||
|
||||
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "set",
|
||||
})
|
||||
}
|
||||
|
||||
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
|
||||
var rows []map[string]interface{}
|
||||
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
var annotations interface{}
|
||||
if len(rows) > 0 {
|
||||
// The procedure returns a single value; extract the first column of the first row.
|
||||
for _, v := range rows[0] {
|
||||
annotations = v
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
|
||||
switch v := annotations.(type) {
|
||||
case []byte:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal(v, &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
case string:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal([]byte(v), &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "get",
|
||||
"annotations": annotations,
|
||||
})
|
||||
}
|
||||
71
pkg/resolvemcp/context.go
Normal file
71
pkg/resolvemcp/context.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package resolvemcp
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeySchema contextKey = "schema"
|
||||
contextKeyEntity contextKey = "entity"
|
||||
contextKeyTableName contextKey = "tableName"
|
||||
contextKeyModel contextKey = "model"
|
||||
contextKeyModelPtr contextKey = "modelPtr"
|
||||
)
|
||||
|
||||
func WithSchema(ctx context.Context, schema string) context.Context {
|
||||
return context.WithValue(ctx, contextKeySchema, schema)
|
||||
}
|
||||
|
||||
func GetSchema(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeySchema); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithEntity(ctx context.Context, entity string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyEntity, entity)
|
||||
}
|
||||
|
||||
func GetEntity(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyEntity); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithTableName(ctx context.Context, tableName string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyTableName, tableName)
|
||||
}
|
||||
|
||||
func GetTableName(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyTableName); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithModel(ctx context.Context, model interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModel, model)
|
||||
}
|
||||
|
||||
func GetModel(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModel)
|
||||
}
|
||||
|
||||
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
|
||||
}
|
||||
|
||||
func GetModelPtr(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModelPtr)
|
||||
}
|
||||
|
||||
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||
ctx = WithSchema(ctx, schema)
|
||||
ctx = WithEntity(ctx, entity)
|
||||
ctx = WithTableName(ctx, tableName)
|
||||
ctx = WithModel(ctx, model)
|
||||
ctx = WithModelPtr(ctx, modelPtr)
|
||||
return ctx
|
||||
}
|
||||
161
pkg/resolvemcp/cursor.go
Normal file
161
pkg/resolvemcp/cursor.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package resolvemcp
|
||||
|
||||
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
type cursorDirection int
|
||||
|
||||
const (
|
||||
cursorForward cursorDirection = 1
|
||||
cursorBackward cursorDirection = -1
|
||||
)
|
||||
|
||||
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
|
||||
func getCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
|
||||
cursorID, direction := getActiveCursor(options)
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
sortItems := options.Sort
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
for _, s := range sortItems {
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
desc := strings.EqualFold(s.Direction, "desc")
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
orSQL := buildCursorPriorityChain(whereClauses)
|
||||
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
|
||||
if options.CursorForward != "" {
|
||||
return options.CursorForward, cursorForward
|
||||
}
|
||||
if options.CursorBackward != "" {
|
||||
return options.CursorBackward, cursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
func buildCursorPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
736
pkg/resolvemcp/handler.go
Normal file
736
pkg/resolvemcp/handler.go
Normal file
@@ -0,0 +1,736 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Handler exposes registered database models as MCP tools and resources.
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||
h := &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||
config: cfg,
|
||||
name: "resolvemcp",
|
||||
version: "1.0.0",
|
||||
}
|
||||
registerAnnotationTool(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry.
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// GetDatabase returns the underlying database.
|
||||
func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
|
||||
func (h *Handler) MCPServer() *server.MCPServer {
|
||||
return h.mcpServer
|
||||
}
|
||||
|
||||
// SSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
|
||||
// detected automatically from each incoming request.
|
||||
func (h *Handler) SSEServer() http.Handler {
|
||||
if h.config.BaseURL != "" {
|
||||
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
|
||||
}
|
||||
return &dynamicSSEHandler{h: h}
|
||||
}
|
||||
|
||||
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
|
||||
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
h.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithStaticBasePath(basePath),
|
||||
)
|
||||
}
|
||||
|
||||
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
|
||||
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
|
||||
type dynamicSSEHandler struct {
|
||||
h *Handler
|
||||
mu sync.Mutex
|
||||
pool map[string]*server.SSEServer
|
||||
}
|
||||
|
||||
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := requestBaseURL(r)
|
||||
|
||||
d.mu.Lock()
|
||||
if d.pool == nil {
|
||||
d.pool = make(map[string]*server.SSEServer)
|
||||
}
|
||||
s, ok := d.pool[baseURL]
|
||||
if !ok {
|
||||
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
|
||||
d.pool[baseURL] = s
|
||||
}
|
||||
d.mu.Unlock()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// requestBaseURL builds the base URL from an incoming request.
|
||||
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
|
||||
func requestBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
}
|
||||
return scheme + "://" + r.Host
|
||||
}
|
||||
|
||||
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := h.registry.RegisterModel(fullName, model); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model and sets per-entity operation rules
|
||||
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetModelRules updates the operation rules for an already-registered model.
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
return reg.SetModelRules(buildModelName(schema, entity), rules)
|
||||
}
|
||||
|
||||
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||
func buildModelName(schema, entity string) string {
|
||||
if schema == "" {
|
||||
return entity
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schema, entity)
|
||||
}
|
||||
|
||||
// getTableName returns the fully qualified table name for a model.
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||
tableName := tableProvider.TableName()
|
||||
if idx := strings.LastIndex(tableName, "."); idx != -1 {
|
||||
return tableName[:idx], tableName[idx+1:]
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), tableName
|
||||
}
|
||||
return defaultSchema, tableName
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), entity
|
||||
}
|
||||
return defaultSchema, entity
|
||||
}
|
||||
|
||||
// executeRead reads records from the database and returns raw data + metadata.
|
||||
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
unwrapped, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = unwrapped.Model
|
||||
modelType := unwrapped.ModelType
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
|
||||
|
||||
validator := common.NewColumnValidator(model)
|
||||
options = validator.FilterRequestOptions(options)
|
||||
|
||||
// BeforeHandle hook
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "read",
|
||||
Options: options,
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||
modelPtr := reflect.New(sliceType).Interface()
|
||||
|
||||
query := h.db.NewSelect().Model(modelPtr)
|
||||
|
||||
tempInstance := reflect.New(modelType).Interface()
|
||||
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
// Column selection
|
||||
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
for _, col := range options.Columns {
|
||||
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||
}
|
||||
for _, cu := range options.ComputedColumns {
|
||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||
}
|
||||
|
||||
// Preloads
|
||||
if len(options.Preload) > 0 {
|
||||
var err error
|
||||
query, err = h.applyPreloads(model, query, options.Preload)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Filters
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
// Custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
query = query.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Sorting
|
||||
for _, sort := range options.Sort {
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Cursor pagination
|
||||
if options.CursorForward != "" || options.CursorBackward != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// expandJoins is empty for resolvemcp — no custom SQL join support yet
|
||||
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||
}
|
||||
|
||||
if cursorFilter != "" {
|
||||
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
sanitized = common.EnsureOuterParentheses(sanitized)
|
||||
if sanitized != "" {
|
||||
query = query.Where(sanitized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||
}
|
||||
|
||||
// Pagination
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
query = query.Limit(*options.Limit)
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
|
||||
// BeforeRead hook
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var data interface{}
|
||||
if id != "" {
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, fmt.Errorf("record not found")
|
||||
}
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = singleResult
|
||||
} else {
|
||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||
}
|
||||
|
||||
limit := 0
|
||||
offset := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Count is the number of records in this page, not the total.
|
||||
var pageCount int64
|
||||
if id != "" {
|
||||
pageCount = 1
|
||||
} else {
|
||||
pageCount = int64(reflect.ValueOf(data).Len())
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Count: pageCount,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// AfterRead hook
|
||||
hookCtx.Result = data
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return data, metadata, nil
|
||||
}
|
||||
|
||||
// executeCreate inserts one or more records.
|
||||
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "create",
|
||||
Data: data,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use potentially modified data
|
||||
data = hookCtx.Data
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = v
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case []interface{}:
|
||||
results := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("each item must be an object")
|
||||
}
|
||||
q := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
q = q.Value(key, value)
|
||||
}
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = results
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("data must be an object or array of objects")
|
||||
}
|
||||
}
|
||||
|
||||
// executeUpdate updates a record by ID.
|
||||
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
updates, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("data must be an object")
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
if idVal, exists := updates["id"]; exists {
|
||||
id = fmt.Sprintf("%v", idVal)
|
||||
}
|
||||
}
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
var updateResult interface{}
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Read existing record
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
existingRecord := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
return fmt.Errorf("error fetching existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "update",
|
||||
ID: id,
|
||||
Data: updates,
|
||||
Tx: tx,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
updates = modifiedData
|
||||
}
|
||||
|
||||
// Merge non-nil, non-empty values
|
||||
for key, newValue := range updates {
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
res, err := q.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating record: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
|
||||
updateResult = existingMap
|
||||
hookCtx.Result = updateResult
|
||||
return h.hooks.Execute(AfterUpdate, hookCtx)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updateResult, nil
|
||||
}
|
||||
|
||||
// executeDelete deletes a record by ID.
|
||||
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "delete",
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
var recordToDelete interface{}
|
||||
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
record := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(record).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("record not found")
|
||||
}
|
||||
return fmt.Errorf("error fetching record: %w", err)
|
||||
}
|
||||
|
||||
res, err := tx.NewDelete().Table(tableName).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete error: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("record not found or already deleted")
|
||||
}
|
||||
|
||||
recordToDelete = record
|
||||
hookCtx.Tx = tx
|
||||
hookCtx.Result = record
|
||||
return h.hooks.Execute(AfterDelete, hookCtx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
|
||||
return recordToDelete, nil
|
||||
}
|
||||
|
||||
// applyFilters applies all filters with OR grouping logic.
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||
case "neq", "!=", "<>":
|
||||
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gt", ">":
|
||||
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gte", ">=":
|
||||
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lt", "<":
|
||||
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lte", "<=":
|
||||
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return condition, args
|
||||
case "is_null":
|
||||
return fmt.Sprintf("%s IS NULL", filter.Column), nil
|
||||
case "is_not_null":
|
||||
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for i := range preloads {
|
||||
preload := &preloads[i]
|
||||
if preload.Relation == "" {
|
||||
continue
|
||||
}
|
||||
query = query.PreloadRelation(preload.Relation)
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
113
pkg/resolvemcp/hooks.go
Normal file
113
pkg/resolvemcp/hooks.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
|
||||
BeforeCreate HookType = "before_create"
|
||||
AfterCreate HookType = "after_create"
|
||||
|
||||
BeforeUpdate HookType = "before_update"
|
||||
AfterUpdate HookType = "after_update"
|
||||
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler
|
||||
Schema string
|
||||
Entity string
|
||||
Model interface{}
|
||||
Options common.RequestOptions
|
||||
Operation string
|
||||
ID string
|
||||
Data interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
Query common.SelectQuery
|
||||
Abort bool
|
||||
AbortMessage string
|
||||
AbortCode int
|
||||
Tx common.Database
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
if ctx.Abort {
|
||||
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
return exists && len(hooks) > 0
|
||||
}
|
||||
100
pkg/resolvemcp/resolvemcp.go
Normal file
100
pkg/resolvemcp/resolvemcp.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
|
||||
// and resources over HTTP/SSE transport.
|
||||
//
|
||||
// It mirrors the resolvespec package patterns:
|
||||
// - Same model registration API
|
||||
// - Same filter, sort, cursor pagination, preload options
|
||||
// - Same lifecycle hook system
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// resolvemcp.SetupMuxRoutes(r, handler)
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/uptrace/bun"
|
||||
bunrouter "github.com/uptrace/bunrouter"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// Config holds configuration for the resolvemcp handler.
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
|
||||
// If empty, the path is detected from each incoming request automatically.
|
||||
BasePath string
|
||||
}
|
||||
|
||||
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
|
||||
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
|
||||
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
|
||||
//
|
||||
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||
// before calling SetupMuxRoutes.
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
|
||||
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
|
||||
|
||||
// Convenience: also expose the full SSE server at basePath for clients that
|
||||
// use ServeHTTP directly (e.g. net/http default mux).
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
|
||||
// using the base path from Config.BasePath.
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint
|
||||
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewSSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// If Config.BasePath is set it is used directly; otherwise the base path is
|
||||
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
|
||||
//
|
||||
// h := resolvemcp.NewSSEServer(handler)
|
||||
// http.Handle("/api/mcp/", h)
|
||||
func NewSSEServer(handler *Handler) http.Handler {
|
||||
return handler.SSEServer()
|
||||
}
|
||||
115
pkg/resolvemcp/security_hooks.go
Normal file
115
pkg/resolvemcp/security_hooks.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks wires the security package's access-control layer into the
|
||||
// resolvemcp handler. Call it once after creating the handler, before registering models.
|
||||
//
|
||||
// The following controls are applied:
|
||||
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
|
||||
// stored via RegisterModelWithRules / SetModelRules.
|
||||
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
|
||||
// - Column-level security: sensitive columns masked/hidden in read results.
|
||||
// - Audit logging after each read.
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// BeforeHandle: enforce model-level operation rules (auth check).
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
|
||||
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (2nd): audit log.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.LogDataAccess(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeUpdate: enforce CanUpdate rule.
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeDelete: enforce CanDelete rule.
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvemcp handler")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
return s.ctx.Query
|
||||
}
|
||||
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if q, ok := query.(common.SelectQuery); ok {
|
||||
s.ctx.Query = q
|
||||
}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
692
pkg/resolvemcp/tools.go
Normal file
692
pkg/resolvemcp/tools.go
Normal file
@@ -0,0 +1,692 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// toolName builds the MCP tool name for a given operation and model.
|
||||
func toolName(operation, schema, entity string) string {
|
||||
if schema == "" {
|
||||
return fmt.Sprintf("%s_%s", operation, entity)
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
|
||||
}
|
||||
|
||||
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||
info := buildModelInfo(schema, entity, model)
|
||||
registerReadTool(h, schema, entity, info)
|
||||
registerCreateTool(h, schema, entity, info)
|
||||
registerUpdateTool(h, schema, entity, info)
|
||||
registerDeleteTool(h, schema, entity, info)
|
||||
registerModelResource(h, schema, entity, info)
|
||||
|
||||
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Model introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
|
||||
type modelInfo struct {
|
||||
fullName string // e.g. "public.users"
|
||||
pkName string // e.g. "id"
|
||||
columns []columnInfo
|
||||
relationNames []string
|
||||
schemaDoc string // formatted multi-line schema listing
|
||||
}
|
||||
|
||||
type columnInfo struct {
|
||||
jsonName string
|
||||
sqlName string
|
||||
goType string
|
||||
sqlType string
|
||||
isPrimary bool
|
||||
isUnique bool
|
||||
isFK bool
|
||||
nullable bool
|
||||
}
|
||||
|
||||
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
|
||||
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
info := modelInfo{
|
||||
fullName: buildModelName(schema, entity),
|
||||
pkName: reflection.GetPrimaryKeyName(model),
|
||||
}
|
||||
|
||||
// Unwrap to base struct type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return info
|
||||
}
|
||||
|
||||
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
|
||||
|
||||
for _, d := range details {
|
||||
// Derive the JSON name from the struct field
|
||||
jsonName := fieldJSONName(modelType, d.Name)
|
||||
if jsonName == "" || jsonName == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (slice or user-defined struct that isn't time.Time).
|
||||
fieldType, found := modelType.FieldByName(d.Name)
|
||||
if found {
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||
if ft.Kind() == reflect.Slice || isUserStruct {
|
||||
info.relationNames = append(info.relationNames, jsonName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
sqlName := d.SQLName
|
||||
if sqlName == "" {
|
||||
sqlName = jsonName
|
||||
}
|
||||
|
||||
// Derive Go type name, unwrapping pointer if needed.
|
||||
goType := d.DataType
|
||||
if goType == "" && found {
|
||||
ft := fieldType.Type
|
||||
for ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
goType = ft.Name()
|
||||
}
|
||||
|
||||
// isPrimary: use both the GORM-tag detection and a name comparison against
|
||||
// the known primary key (handles camelCase "primaryKey" tags correctly).
|
||||
isPrimary := d.SQLKey == "primary_key" ||
|
||||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
|
||||
|
||||
ci := columnInfo{
|
||||
jsonName: jsonName,
|
||||
sqlName: sqlName,
|
||||
goType: goType,
|
||||
sqlType: d.SQLDataType,
|
||||
isPrimary: isPrimary,
|
||||
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
|
||||
isFK: d.SQLKey == "foreign_key",
|
||||
nullable: d.Nullable,
|
||||
}
|
||||
info.columns = append(info.columns, ci)
|
||||
}
|
||||
|
||||
info.schemaDoc = buildSchemaDoc(info)
|
||||
return info
|
||||
}
|
||||
|
||||
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
|
||||
func fieldJSONName(modelType reflect.Type, fieldName string) string {
|
||||
field, ok := modelType.FieldByName(fieldName)
|
||||
if !ok {
|
||||
return fieldName
|
||||
}
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "" {
|
||||
return fieldName
|
||||
}
|
||||
parts := strings.SplitN(tag, ",", 2)
|
||||
if parts[0] == "" {
|
||||
return fieldName
|
||||
}
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
|
||||
func buildSchemaDoc(info modelInfo) string {
|
||||
if len(info.columns) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Columns:\n")
|
||||
for _, c := range info.columns {
|
||||
line := fmt.Sprintf(" • %s", c.jsonName)
|
||||
|
||||
typeDesc := c.goType
|
||||
if c.sqlType != "" {
|
||||
typeDesc = c.sqlType
|
||||
}
|
||||
if typeDesc != "" {
|
||||
line += fmt.Sprintf(" (%s)", typeDesc)
|
||||
}
|
||||
|
||||
var flags []string
|
||||
if c.isPrimary {
|
||||
flags = append(flags, "primary key")
|
||||
}
|
||||
if c.isUnique {
|
||||
flags = append(flags, "unique")
|
||||
}
|
||||
if c.isFK {
|
||||
flags = append(flags, "foreign key")
|
||||
}
|
||||
if !c.nullable && !c.isPrimary {
|
||||
flags = append(flags, "not null")
|
||||
} else if c.nullable {
|
||||
flags = append(flags, "nullable")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
line += " — " + strings.Join(flags, ", ")
|
||||
}
|
||||
|
||||
sb.WriteString(line + "\n")
|
||||
}
|
||||
|
||||
if len(info.relationNames) > 0 {
|
||||
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
|
||||
func columnNameList(cols []columnInfo) string {
|
||||
names := make([]string, len(cols))
|
||||
for i, c := range cols {
|
||||
names[i] = c.jsonName
|
||||
}
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// writableColumnNames returns JSON names for all non-primary-key columns.
|
||||
func writableColumnNames(cols []columnInfo) []string {
|
||||
var names []string
|
||||
for _, c := range cols {
|
||||
if !c.isPrimary {
|
||||
names = append(names, c.jsonName)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Read tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("read", schema, entity)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
|
||||
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
|
||||
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
|
||||
)
|
||||
if len(info.relationNames) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
|
||||
if len(info.columns) > 0 {
|
||||
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
|
||||
if len(info.columns) > 0 {
|
||||
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
|
||||
),
|
||||
mcp.WithNumber("limit",
|
||||
mcp.Description("Maximum number of records to return per page. Recommended: 10–100."),
|
||||
),
|
||||
mcp.WithNumber("offset",
|
||||
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
|
||||
),
|
||||
mcp.WithString("cursor_forward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithString("cursor_backward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithArray("columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("omit_columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("filters",
|
||||
mcp.Description(filterDesc),
|
||||
),
|
||||
mcp.WithArray("sort",
|
||||
mcp.Description(sortDesc),
|
||||
),
|
||||
mcp.WithArray("preloads",
|
||||
mcp.Description(buildPreloadDesc(info)),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
options := parseRequestOptions(args)
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func buildPreloadDesc(info modelInfo) string {
|
||||
if len(info.relationNames) == 0 {
|
||||
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
|
||||
strings.Join(info.relationNames, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("create", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
dataDesc := "Record fields to create."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
dataDesc += " Pass a single object or an array of objects."
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeCreate(ctx, schema, entity, data)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Update tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("update", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
|
||||
}
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
|
||||
|
||||
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(idDesc),
|
||||
),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("data must be an object"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Delete tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("delete", schema, entity)
|
||||
|
||||
descParts := []string{
|
||||
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
|
||||
|
||||
description := strings.Join(descParts, " ")
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
result, err := h.executeDelete(ctx, schema, entity, id)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Resource registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
|
||||
resourceURI := info.fullName
|
||||
|
||||
var resourceDesc strings.Builder
|
||||
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
|
||||
if info.pkName != "" {
|
||||
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
resourceDesc.WriteString("\n\n")
|
||||
resourceDesc.WriteString(info.schemaDoc)
|
||||
}
|
||||
|
||||
resource := mcp.NewResource(
|
||||
resourceURI,
|
||||
entity,
|
||||
mcp.WithResourceDescription(resourceDesc.String()),
|
||||
mcp.WithMIMEType("application/json"),
|
||||
)
|
||||
|
||||
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
limit := 100
|
||||
options := common.RequestOptions{Limit: &limit}
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
}
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling resource: %w", err)
|
||||
}
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: req.Params.URI,
|
||||
MIMEType: "application/json",
|
||||
Text: string(jsonBytes),
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Argument parsing helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
|
||||
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||
options := common.RequestOptions{}
|
||||
|
||||
if v, ok := args["limit"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
limit := int(n)
|
||||
options.Limit = &limit
|
||||
case int:
|
||||
options.Limit = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["offset"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
offset := int(n)
|
||||
options.Offset = &offset
|
||||
case int:
|
||||
options.Offset = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["cursor_forward"].(string); ok {
|
||||
options.CursorForward = v
|
||||
}
|
||||
if v, ok := args["cursor_backward"].(string); ok {
|
||||
options.CursorBackward = v
|
||||
}
|
||||
|
||||
options.Columns = parseStringArray(args["columns"])
|
||||
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||
options.Filters = parseFilters(args["filters"])
|
||||
options.Sort = parseSortOptions(args["sort"])
|
||||
options.Preload = parsePreloadOptions(args["preloads"])
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
func parseStringArray(raw interface{}) []string {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseFilters(raw interface{}) []common.FilterOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.FilterOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var f common.FilterOption
|
||||
if err := json.Unmarshal(b, &f); err != nil {
|
||||
continue
|
||||
}
|
||||
if f.Column == "" || f.Operator == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(f.LogicOperator, "or") {
|
||||
f.LogicOperator = "OR"
|
||||
} else {
|
||||
f.LogicOperator = "AND"
|
||||
}
|
||||
result = append(result, f)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseSortOptions(raw interface{}) []common.SortOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.SortOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var s common.SortOption
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
continue
|
||||
}
|
||||
if s.Column == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.PreloadOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var p common.PreloadOption
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
continue
|
||||
}
|
||||
if p.Relation == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// marshalResult marshals a value to JSON and returns it as an MCP text result.
|
||||
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
|
||||
}
|
||||
return mcp.NewToolResultText(string(b)), nil
|
||||
}
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - options: the request options containing sort and cursor information
|
||||
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func GetCursorFilter(
|
||||
@@ -31,6 +32,7 @@ func GetCursorFilter(
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
// Separate schema prefix from bare table name
|
||||
fullTableName := tableName
|
||||
@@ -58,18 +60,19 @@ func GetCursorFilter(
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "created_at", "user.name", etc.
|
||||
// Parse: "created_at", "user.name", "fn.sortorder", etc.
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
@@ -82,7 +85,7 @@ func GetCursorFilter(
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, err := resolveColumn(
|
||||
cursorCol, targetCol, isJoin, err := resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -90,6 +93,22 @@ func GetCursorFilter(
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
@@ -113,10 +132,12 @@ func GetCursorFilter(
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
@@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: resolve column (main table only for now)
|
||||
// Helper: resolve column (main table or join)
|
||||
func resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, err error) {
|
||||
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Joined column (not supported in resolvespec yet)
|
||||
// Joined column
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// Helper: rewrite JOIN clause for cursor subquery
|
||||
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
|
||||
CursorForward: "8975",
|
||||
}
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Joined column (not supported)",
|
||||
name: "Joined column (isJoin=true, no error)",
|
||||
field: "name",
|
||||
prefix: "user",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "title"},
|
||||
wantErr: true,
|
||||
wantErr: false,
|
||||
// cursorCol and targetCol are empty when isJoin=true; handled by caller
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// For join columns, cursor/target are empty and isJoin=true
|
||||
if isJoin {
|
||||
if cursor != "" || target != "" {
|
||||
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cursor != tt.wantCursor {
|
||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||
}
|
||||
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// Get cursor filter SQL
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
logger.Error("Error building cursor filter: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||
|
||||
@@ -64,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
@@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin && expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "fn.sortorder", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "8975"
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
// modelColumns does not contain "sortorder" - it's a lateral join computed column
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
// Should contain the rewritten lateral join inside the EXISTS subquery
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should compare fn.sortorder values
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should NOT contain empty comparison like "< "
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > posts.priority",
|
||||
|
||||
@@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Extract model columns for validation using the generic database function
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Build expand joins map (if needed in future)
|
||||
var expandJoins map[string]string
|
||||
if len(options.Expand) > 0 {
|
||||
expandJoins = make(map[string]string)
|
||||
// TODO: Build actual JOIN SQL for each expand relation
|
||||
// For now, pass empty map as joins are handled via Preload
|
||||
// Build expand joins map: custom SQL joins are available in cursor subquery
|
||||
expandJoins := make(map[string]string)
|
||||
for _, joinClause := range options.CustomSQLJoin {
|
||||
alias := extractJoinAlias(joinClause)
|
||||
if alias != "" {
|
||||
expandJoins[alias] = joinClause
|
||||
}
|
||||
}
|
||||
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
|
||||
|
||||
// Default sort to primary key when none provided
|
||||
if len(options.Sort) == 0 {
|
||||
|
||||
@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve relation names (convert table names to field names) if model is provided
|
||||
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
|
||||
if model != nil && !options.XFilesPresent {
|
||||
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
|
||||
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
|
||||
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
|
||||
// disambiguate when multiple fields point to the same related type.
|
||||
if model != nil {
|
||||
h.resolveRelationNamesInOptions(&options, model)
|
||||
}
|
||||
|
||||
@@ -550,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
|
||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||
// - "JOIN roles r ON ..." -> "r"
|
||||
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
|
||||
func extractJoinAlias(joinClause string) string {
|
||||
// Pattern: JOIN table_name [AS] alias ON ...
|
||||
// We need to extract the alias (word before ON)
|
||||
|
||||
upperJoin := strings.ToUpper(joinClause)
|
||||
|
||||
// Find the "JOIN" keyword position
|
||||
@@ -562,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the "ON" keyword position
|
||||
// Lateral joins: alias is the word after the closing ) and before ON
|
||||
if strings.Contains(upperJoin, "LATERAL") {
|
||||
lastClose := strings.LastIndex(joinClause, ")")
|
||||
if lastClose != -1 {
|
||||
words := strings.Fields(joinClause[lastClose+1:])
|
||||
// words should be like ["fn", "on", "true"] or ["on", "true"]
|
||||
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
|
||||
return words[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Regular joins: find the "ON" keyword position (first occurrence)
|
||||
onIdx := strings.Index(upperJoin, " ON ")
|
||||
if onIdx == -1 {
|
||||
return ""
|
||||
@@ -863,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
|
||||
|
||||
// Resolve each part of the path
|
||||
currentModel := model
|
||||
for _, part := range parts {
|
||||
resolvedPart := h.resolveRelationName(currentModel, part)
|
||||
for partIdx, part := range parts {
|
||||
isLast := partIdx == len(parts)-1
|
||||
var resolvedPart string
|
||||
if isLast {
|
||||
// For the final part, use join-key-aware resolution to disambiguate when
|
||||
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
|
||||
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
|
||||
localKey := preload.RelatedKey
|
||||
if localKey == "" {
|
||||
localKey = preload.ForeignKey
|
||||
}
|
||||
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
|
||||
} else {
|
||||
resolvedPart = h.resolveRelationName(currentModel, part)
|
||||
}
|
||||
resolvedParts = append(resolvedParts, resolvedPart)
|
||||
|
||||
// Try to get the model type for the next level
|
||||
@@ -980,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
|
||||
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
|
||||
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
|
||||
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
|
||||
if localKey == "" {
|
||||
return h.resolveRelationName(model, nameOrTable)
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nameOrTable
|
||||
}
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// If it's already a direct field name, return as-is (no ambiguity).
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if modelType.Field(i).Name == nameOrTable {
|
||||
return nameOrTable
|
||||
}
|
||||
}
|
||||
|
||||
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||
localKeyLower := strings.ToLower(localKey)
|
||||
|
||||
// Find all fields whose related type matches nameOrTable, then pick the one
|
||||
// whose bun join tag local key matches localKey.
|
||||
var fallbackField string
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
fieldType := field.Type
|
||||
|
||||
var targetType reflect.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
targetType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Ptr {
|
||||
targetType = fieldType.Elem()
|
||||
}
|
||||
if targetType != nil && targetType.Kind() == reflect.Ptr {
|
||||
targetType = targetType.Elem()
|
||||
}
|
||||
if targetType == nil || targetType.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
normalizedTypeName := strings.ToLower(targetType.Name())
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||
if normalizedTypeName != normalizedInput {
|
||||
continue
|
||||
}
|
||||
|
||||
// Type name matches; record as fallback.
|
||||
if fallbackField == "" {
|
||||
fallbackField = field.Name
|
||||
}
|
||||
|
||||
// Check bun join tag: "join:localKey=foreignKey"
|
||||
bunTag := field.Tag.Get("bun")
|
||||
for _, tagPart := range strings.Split(bunTag, ",") {
|
||||
tagPart = strings.TrimSpace(tagPart)
|
||||
if !strings.HasPrefix(tagPart, "join:") {
|
||||
continue
|
||||
}
|
||||
joinSpec := strings.TrimPrefix(tagPart, "join:")
|
||||
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
|
||||
joinCols := strings.Fields(joinSpec)
|
||||
if len(joinCols) == 0 {
|
||||
joinCols = []string{joinSpec}
|
||||
}
|
||||
for _, joinCol := range joinCols {
|
||||
eqIdx := strings.Index(joinCol, "=")
|
||||
if eqIdx < 0 {
|
||||
continue
|
||||
}
|
||||
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
|
||||
if joinLocalKey == localKeyLower {
|
||||
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
|
||||
return field.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackField != "" {
|
||||
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
|
||||
return fallbackField
|
||||
}
|
||||
return h.resolveRelationName(model, nameOrTable)
|
||||
}
|
||||
|
||||
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||
// and recursively processes its children
|
||||
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||
|
||||
@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
|
||||
joinClause: "LEFT JOIN departments",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with alias",
|
||||
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with multiline subquery containing inner ON",
|
||||
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
153
pkg/security/KEYSTORE.md
Normal file
153
pkg/security/KEYSTORE.md
Normal file
@@ -0,0 +1,153 @@
|
||||
# Keystore
|
||||
|
||||
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
|
||||
|
||||
## Key types
|
||||
|
||||
| Constant | Value | Use case |
|
||||
|---|---|---|
|
||||
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
|
||||
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
|
||||
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
|
||||
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
|
||||
|
||||
## Storage backends
|
||||
|
||||
### ConfigKeyStore
|
||||
|
||||
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
|
||||
|
||||
```go
|
||||
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
|
||||
store := security.NewConfigKeyStore([]security.UserKey{
|
||||
{
|
||||
UserID: 1,
|
||||
KeyType: security.KeyTypeGenericAPI,
|
||||
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
|
||||
Name: "CI deploy",
|
||||
Scopes: []string{"deploy"},
|
||||
IsActive: true,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### DatabaseKeyStore
|
||||
|
||||
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
|
||||
|
||||
```go
|
||||
db, _ := sql.Open("postgres", dsn)
|
||||
|
||||
store := security.NewDatabaseKeyStore(db)
|
||||
|
||||
// With options
|
||||
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||
CacheTTL: 5 * time.Minute,
|
||||
SQLNames: &security.KeyStoreSQLNames{
|
||||
ValidateKey: "myapp_keystore_validate", // override one procedure name
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Managing keys
|
||||
|
||||
```go
|
||||
ctx := context.Background()
|
||||
|
||||
// Create — raw key returned once; store it securely
|
||||
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
|
||||
UserID: 42,
|
||||
KeyType: security.KeyTypeGenericAPI,
|
||||
Name: "mobile app",
|
||||
Scopes: []string{"read", "write"},
|
||||
})
|
||||
fmt.Println(resp.RawKey) // only shown here; hashed internally
|
||||
|
||||
// List
|
||||
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
|
||||
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
|
||||
|
||||
// Revoke
|
||||
err = store.DeleteKey(ctx, 42, resp.Key.ID)
|
||||
|
||||
// Validate (used by authenticators internally)
|
||||
key, err := store.ValidateKey(ctx, rawKey, "")
|
||||
```
|
||||
|
||||
## HTTP authentication
|
||||
|
||||
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
|
||||
|
||||
Keys are extracted from the request in this order:
|
||||
|
||||
1. `Authorization: Bearer <key>`
|
||||
2. `Authorization: ApiKey <key>`
|
||||
3. `X-API-Key: <key>`
|
||||
|
||||
```go
|
||||
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
|
||||
// Restrict to a specific type:
|
||||
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
|
||||
```
|
||||
|
||||
Plug it into a handler:
|
||||
|
||||
```go
|
||||
handler := resolvespec.NewHandler(db, registry,
|
||||
resolvespec.WithAuthenticator(auth),
|
||||
)
|
||||
```
|
||||
|
||||
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
|
||||
|
||||
On successful validation the request context receives a `UserContext` where:
|
||||
|
||||
- `UserID` — from the key
|
||||
- `Roles` — the key's `Scopes`
|
||||
- `Claims["key_type"]` — key type string
|
||||
- `Claims["key_name"]` — key name
|
||||
|
||||
## Database setup
|
||||
|
||||
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
|
||||
|
||||
```sql
|
||||
\i pkg/security/keystore_schema.sql
|
||||
```
|
||||
|
||||
This creates:
|
||||
|
||||
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
|
||||
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
|
||||
- `resolvespec_keystore_create_key(p_request jsonb)`
|
||||
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
|
||||
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
|
||||
|
||||
### Custom procedure names
|
||||
|
||||
```go
|
||||
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||
SQLNames: &security.KeyStoreSQLNames{
|
||||
GetUserKeys: "myschema_get_keys",
|
||||
CreateKey: "myschema_create_key",
|
||||
DeleteKey: "myschema_delete_key",
|
||||
ValidateKey: "myschema_validate_key",
|
||||
},
|
||||
})
|
||||
|
||||
// Validate names at startup
|
||||
names := &security.KeyStoreSQLNames{
|
||||
GetUserKeys: "myschema_get_keys",
|
||||
// ...
|
||||
}
|
||||
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
## Security notes
|
||||
|
||||
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
|
||||
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
|
||||
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
|
||||
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.
|
||||
@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||
// Add to blacklist
|
||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
||||
"token": req.Token,
|
||||
"user_id": req.UserID,
|
||||
}).Error
|
||||
// Invalidate session via stored procedure
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||
|
||||
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// For JWT, logout could involve token blacklisting
|
||||
// Add token to blacklist table
|
||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
||||
// "token": req.Token,
|
||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
||||
// }).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
81
pkg/security/keystore.go
Normal file
81
pkg/security/keystore.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
|
||||
// Used by all keystore implementations to hash raw keys before storage or lookup.
|
||||
func hashSHA256Hex(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// KeyType identifies the category of an auth key.
|
||||
type KeyType string
|
||||
|
||||
const (
|
||||
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
|
||||
KeyTypeJWTSecret KeyType = "jwt_secret"
|
||||
// KeyTypeHeaderAPI is a static API key sent via a request header.
|
||||
KeyTypeHeaderAPI KeyType = "header_api"
|
||||
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
|
||||
KeyTypeOAuth2 KeyType = "oauth2"
|
||||
// KeyTypeGenericAPI is a generic application API key.
|
||||
KeyTypeGenericAPI KeyType = "api"
|
||||
)
|
||||
|
||||
// UserKey represents a single named auth key belonging to a user.
|
||||
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
|
||||
type UserKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int `json:"user_id"`
|
||||
KeyType KeyType `json:"key_type"`
|
||||
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// CreateKeyRequest specifies the parameters for a new key.
|
||||
type CreateKeyRequest struct {
|
||||
UserID int
|
||||
KeyType KeyType
|
||||
Name string
|
||||
Scopes []string
|
||||
Meta map[string]any
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
// CreateKeyResponse is returned exactly once when a key is created.
|
||||
// The caller is responsible for persisting RawKey; it is not stored anywhere.
|
||||
type CreateKeyResponse struct {
|
||||
Key UserKey
|
||||
RawKey string // crypto/rand 32 bytes, base64url-encoded
|
||||
}
|
||||
|
||||
// KeyStore manages per-user auth keys with pluggable storage backends.
|
||||
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
|
||||
type KeyStore interface {
|
||||
// CreateKey generates a new key, stores its hash, and returns the raw key once.
|
||||
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
|
||||
|
||||
// GetUserKeys returns all active, non-expired keys for a user.
|
||||
// Pass an empty KeyType to return all types.
|
||||
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
|
||||
|
||||
// DeleteKey soft-deletes a key by ID after verifying ownership.
|
||||
DeleteKey(ctx context.Context, userID int, keyID int64) error
|
||||
|
||||
// ValidateKey checks a raw key, returns the matching UserKey on success.
|
||||
// The implementation hashes the raw key before any lookup.
|
||||
// Pass an empty KeyType to accept any type.
|
||||
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
|
||||
}
|
||||
97
pkg/security/keystore_authenticator.go
Normal file
97
pkg/security/keystore_authenticator.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
|
||||
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
|
||||
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
|
||||
// is managed directly through the KeyStore.
|
||||
//
|
||||
// Key extraction order:
|
||||
// 1. Authorization: Bearer <key>
|
||||
// 2. Authorization: ApiKey <key>
|
||||
// 3. X-API-Key header
|
||||
type KeyStoreAuthenticator struct {
|
||||
keyStore KeyStore
|
||||
keyType KeyType // empty = accept any type
|
||||
}
|
||||
|
||||
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
|
||||
// Pass an empty keyType to accept keys of any type.
|
||||
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
|
||||
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
|
||||
}
|
||||
|
||||
// Login is not supported for keystore authentication.
|
||||
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
|
||||
return nil, fmt.Errorf("keystore authenticator does not support login")
|
||||
}
|
||||
|
||||
// Logout is not supported for keystore authentication.
|
||||
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate extracts an API key from the request and validates it against the KeyStore.
|
||||
// Returns a UserContext built from the matching UserKey on success.
|
||||
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
rawKey := extractAPIKey(r)
|
||||
if rawKey == "" {
|
||||
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
|
||||
}
|
||||
|
||||
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid API key: %w", err)
|
||||
}
|
||||
|
||||
return userKeyToUserContext(userKey), nil
|
||||
}
|
||||
|
||||
// extractAPIKey extracts a raw key from the request using the following precedence:
|
||||
// 1. Authorization: Bearer <key>
|
||||
// 2. Authorization: ApiKey <key>
|
||||
// 3. X-API-Key header
|
||||
func extractAPIKey(r *http.Request) string {
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||
return strings.TrimSpace(after)
|
||||
}
|
||||
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
|
||||
return strings.TrimSpace(after)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(r.Header.Get("X-API-Key"))
|
||||
}
|
||||
|
||||
// userKeyToUserContext converts a UserKey into a UserContext.
|
||||
// Scopes are mapped to Roles. Key type and name are stored in Claims.
|
||||
func userKeyToUserContext(k *UserKey) *UserContext {
|
||||
claims := map[string]any{
|
||||
"key_type": string(k.KeyType),
|
||||
"key_name": k.Name,
|
||||
}
|
||||
|
||||
meta := k.Meta
|
||||
if meta == nil {
|
||||
meta = map[string]any{}
|
||||
}
|
||||
|
||||
roles := k.Scopes
|
||||
if roles == nil {
|
||||
roles = []string{}
|
||||
}
|
||||
|
||||
return &UserContext{
|
||||
UserID: k.UserID,
|
||||
SessionID: fmt.Sprintf("key:%d", k.ID),
|
||||
Roles: roles,
|
||||
Claims: claims,
|
||||
Meta: meta,
|
||||
}
|
||||
}
|
||||
149
pkg/security/keystore_config.go
Normal file
149
pkg/security/keystore_config.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
|
||||
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
|
||||
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
|
||||
//
|
||||
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
|
||||
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
|
||||
type ConfigKeyStore struct {
|
||||
mu sync.RWMutex
|
||||
keys []UserKey
|
||||
next int64 // monotonic ID counter for runtime-created keys (atomic)
|
||||
}
|
||||
|
||||
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
|
||||
// Pass nil or an empty slice to start with no pre-loaded keys.
|
||||
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
|
||||
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
|
||||
var maxID int64
|
||||
copied := make([]UserKey, len(keys))
|
||||
copy(copied, keys)
|
||||
for i := range copied {
|
||||
if copied[i].CreatedAt.IsZero() {
|
||||
copied[i].IsActive = true
|
||||
copied[i].CreatedAt = time.Now()
|
||||
}
|
||||
if copied[i].ID > maxID {
|
||||
maxID = copied[i].ID
|
||||
}
|
||||
}
|
||||
return &ConfigKeyStore{keys: copied, next: maxID}
|
||||
}
|
||||
|
||||
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
|
||||
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||
rawBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(rawBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||
}
|
||||
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
|
||||
id := atomic.AddInt64(&s.next, 1)
|
||||
key := UserKey{
|
||||
ID: id,
|
||||
UserID: req.UserID,
|
||||
KeyType: req.KeyType,
|
||||
KeyHash: hash,
|
||||
Name: req.Name,
|
||||
Scopes: req.Scopes,
|
||||
Meta: req.Meta,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
CreatedAt: time.Now(),
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.keys = append(s.keys, key)
|
||||
s.mu.Unlock()
|
||||
|
||||
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||
}
|
||||
|
||||
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||
// Pass an empty KeyType to return all types.
|
||||
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||
now := time.Now()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []UserKey
|
||||
for i := range s.keys {
|
||||
k := &s.keys[i]
|
||||
if k.UserID != userID || !k.IsActive {
|
||||
continue
|
||||
}
|
||||
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
if keyType != "" && k.KeyType != keyType {
|
||||
continue
|
||||
}
|
||||
result = append(result, *k)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
|
||||
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i := range s.keys {
|
||||
if s.keys[i].ID == keyID {
|
||||
if s.keys[i].UserID != userID {
|
||||
return fmt.Errorf("key not found or permission denied")
|
||||
}
|
||||
s.keys[i].IsActive = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
|
||||
// Uses constant-time comparison to prevent timing side-channels.
|
||||
// Pass an empty KeyType to accept any type.
|
||||
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
hashBytes, _ := hex.DecodeString(hash)
|
||||
now := time.Now()
|
||||
|
||||
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i := range s.keys {
|
||||
k := &s.keys[i]
|
||||
if !k.IsActive {
|
||||
continue
|
||||
}
|
||||
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
if keyType != "" && k.KeyType != keyType {
|
||||
continue
|
||||
}
|
||||
stored, _ := hex.DecodeString(k.KeyHash)
|
||||
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
|
||||
continue
|
||||
}
|
||||
k.LastUsedAt = &now
|
||||
result := *k
|
||||
return &result, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired key")
|
||||
}
|
||||
256
pkg/security/keystore_database.go
Normal file
256
pkg/security/keystore_database.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
)
|
||||
|
||||
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
|
||||
type DatabaseKeyStoreOptions struct {
|
||||
// Cache is an optional cache instance. If nil, uses the default cache.
|
||||
Cache *cache.Cache
|
||||
// CacheTTL is the duration to cache ValidateKey results.
|
||||
// Default: 2 minutes.
|
||||
CacheTTL time.Duration
|
||||
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
|
||||
SQLNames *KeyStoreSQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
|
||||
// All DB operations go through configurable procedure names; the raw key is
|
||||
// never passed to the database.
|
||||
//
|
||||
// See keystore_schema.sql for the required table and procedure definitions.
|
||||
//
|
||||
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
|
||||
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
|
||||
// (default 2 minutes) if the cache entry cannot be invalidated.
|
||||
type DatabaseKeyStore struct {
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *KeyStoreSQLNames
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
||||
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
|
||||
o := DatabaseKeyStoreOptions{}
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
if o.CacheTTL == 0 {
|
||||
o.CacheTTL = 2 * time.Minute
|
||||
}
|
||||
c := o.Cache
|
||||
if c == nil {
|
||||
c = cache.GetDefaultCache()
|
||||
}
|
||||
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
||||
return &DatabaseKeyStore{
|
||||
db: db,
|
||||
dbFactory: o.DBFactory,
|
||||
sqlNames: names,
|
||||
cache: c,
|
||||
cacheTTL: o.CacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
||||
ks.dbMu.RLock()
|
||||
defer ks.dbMu.RUnlock()
|
||||
return ks.db
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) reconnectDB() error {
|
||||
if ks.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := ks.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ks.dbMu.Lock()
|
||||
ks.db = newDB
|
||||
ks.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
|
||||
// and returns the raw key once.
|
||||
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||
rawBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(rawBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||
}
|
||||
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
|
||||
type createRequest struct {
|
||||
UserID int `json:"user_id"`
|
||||
KeyType KeyType `json:"key_type"`
|
||||
KeyHash string `json:"key_hash"`
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
reqJSON, err := json.Marshal(createRequest{
|
||||
UserID: req.UserID,
|
||||
KeyType: req.KeyType,
|
||||
KeyHash: hash,
|
||||
Name: req.Name,
|
||||
Scopes: req.Scopes,
|
||||
Meta: req.Meta,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keyJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
|
||||
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
return nil, fmt.Errorf("create key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
|
||||
}
|
||||
|
||||
var key UserKey
|
||||
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse created key: %w", err)
|
||||
}
|
||||
|
||||
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||
}
|
||||
|
||||
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||
// Pass an empty KeyType to return all types.
|
||||
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keysJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
|
||||
}
|
||||
|
||||
var keys []UserKey
|
||||
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
|
||||
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user keys: %w", err)
|
||||
}
|
||||
}
|
||||
if keys == nil {
|
||||
keys = []UserKey{}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
|
||||
// The delete procedure returns the key_hash so no separate lookup is needed.
|
||||
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
|
||||
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keyHash sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||
return fmt.Errorf("delete key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
return errors.New(nullStringOr(errorMsg, "delete key failed"))
|
||||
}
|
||||
|
||||
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
|
||||
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey hashes the raw key and calls the validate procedure.
|
||||
// Results are cached for CacheTTL to reduce DB load on hot paths.
|
||||
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
cacheKey := keystoreCacheKey(hash)
|
||||
|
||||
if ks.cache != nil {
|
||||
var cached UserKey
|
||||
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
|
||||
if cached.IsActive {
|
||||
return &cached, nil
|
||||
}
|
||||
return nil, errors.New("invalid or expired key")
|
||||
}
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keyJSON sql.NullString
|
||||
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
|
||||
}
|
||||
if err := runQuery(); err != nil {
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := ks.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
|
||||
}
|
||||
|
||||
var key UserKey
|
||||
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse validated key: %w", err)
|
||||
}
|
||||
|
||||
if ks.cache != nil {
|
||||
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func keystoreCacheKey(hash string) string {
|
||||
return "keystore:validate:" + hash
|
||||
}
|
||||
|
||||
// nullStringOr returns s.String if valid, otherwise the fallback.
|
||||
func nullStringOr(s sql.NullString, fallback string) string {
|
||||
if s.Valid && s.String != "" {
|
||||
return s.String
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
187
pkg/security/keystore_schema.sql
Normal file
187
pkg/security/keystore_schema.sql
Normal file
@@ -0,0 +1,187 @@
|
||||
-- Keystore schema for per-user auth keys
|
||||
-- Apply alongside database_schema.sql (requires the users table)
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_keys (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
key_type VARCHAR(50) NOT NULL,
|
||||
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
|
||||
name VARCHAR(255) NOT NULL DEFAULT '',
|
||||
scopes TEXT, -- JSON array, e.g. '["read","write"]'
|
||||
meta JSONB,
|
||||
expires_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT true
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
|
||||
|
||||
-- resolvespec_keystore_get_user_keys
|
||||
-- Returns all active, non-expired keys for a user.
|
||||
-- Pass empty p_key_type to return all key types.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
|
||||
p_user_id INTEGER,
|
||||
p_key_type TEXT DEFAULT ''
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_keys JSONB;
|
||||
BEGIN
|
||||
SELECT COALESCE(
|
||||
jsonb_agg(
|
||||
jsonb_build_object(
|
||||
'id', k.id,
|
||||
'user_id', k.user_id,
|
||||
'key_type', k.key_type,
|
||||
'name', k.name,
|
||||
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
|
||||
'meta', COALESCE(k.meta, '{}'::jsonb),
|
||||
'expires_at', k.expires_at,
|
||||
'created_at', k.created_at,
|
||||
'last_used_at', k.last_used_at,
|
||||
'is_active', k.is_active
|
||||
)
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
INTO v_keys
|
||||
FROM user_keys k
|
||||
WHERE k.user_id = p_user_id
|
||||
AND k.is_active = true
|
||||
AND (k.expires_at IS NULL OR k.expires_at > NOW())
|
||||
AND (p_key_type = '' OR k.key_type = p_key_type);
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- resolvespec_keystore_create_key
|
||||
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
|
||||
-- Returns the created key record (without key_hash).
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
|
||||
p_request JSONB
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_id BIGINT;
|
||||
v_created_at TIMESTAMP;
|
||||
v_key JSONB;
|
||||
BEGIN
|
||||
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
|
||||
VALUES (
|
||||
(p_request->>'user_id')::INTEGER,
|
||||
p_request->>'key_type',
|
||||
p_request->>'key_hash',
|
||||
COALESCE(p_request->>'name', ''),
|
||||
p_request->>'scopes',
|
||||
p_request->'meta',
|
||||
CASE WHEN p_request->>'expires_at' IS NOT NULL
|
||||
THEN (p_request->>'expires_at')::TIMESTAMP
|
||||
ELSE NULL
|
||||
END
|
||||
)
|
||||
RETURNING id, created_at INTO v_id, v_created_at;
|
||||
|
||||
v_key := jsonb_build_object(
|
||||
'id', v_id,
|
||||
'user_id', (p_request->>'user_id')::INTEGER,
|
||||
'key_type', p_request->>'key_type',
|
||||
'name', COALESCE(p_request->>'name', ''),
|
||||
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
|
||||
THEN (p_request->>'scopes')::jsonb
|
||||
ELSE '[]'::jsonb END,
|
||||
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
|
||||
'expires_at', p_request->>'expires_at',
|
||||
'created_at', v_created_at,
|
||||
'is_active', true
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_key;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- resolvespec_keystore_delete_key
|
||||
-- Soft-deletes a key (is_active = false) after verifying ownership.
|
||||
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
|
||||
p_user_id INTEGER,
|
||||
p_key_id BIGINT
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_hash TEXT;
|
||||
BEGIN
|
||||
UPDATE user_keys
|
||||
SET is_active = false
|
||||
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
|
||||
RETURNING key_hash INTO v_hash;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- resolvespec_keystore_validate_key
|
||||
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
|
||||
-- updates last_used_at, and returns the key record.
|
||||
-- p_key_type can be empty to accept any key type.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
|
||||
p_key_hash TEXT,
|
||||
p_key_type TEXT DEFAULT ''
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_key_rec user_keys%ROWTYPE;
|
||||
v_key JSONB;
|
||||
BEGIN
|
||||
SELECT * INTO v_key_rec
|
||||
FROM user_keys
|
||||
WHERE key_hash = p_key_hash
|
||||
AND is_active = true
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
AND (p_key_type = '' OR key_type = p_key_type);
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
|
||||
|
||||
v_key := jsonb_build_object(
|
||||
'id', v_key_rec.id,
|
||||
'user_id', v_key_rec.user_id,
|
||||
'key_type', v_key_rec.key_type,
|
||||
'name', v_key_rec.name,
|
||||
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
|
||||
THEN v_key_rec.scopes::jsonb
|
||||
ELSE '[]'::jsonb END,
|
||||
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
|
||||
'expires_at', v_key_rec.expires_at,
|
||||
'created_at', v_key_rec.created_at,
|
||||
'last_used_at', NOW(),
|
||||
'is_active', v_key_rec.is_active
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_key;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||
END;
|
||||
$$;
|
||||
61
pkg/security/keystore_sql_names.go
Normal file
61
pkg/security/keystore_sql_names.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package security
|
||||
|
||||
import "fmt"
|
||||
|
||||
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
|
||||
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
|
||||
type KeyStoreSQLNames struct {
|
||||
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
|
||||
CreateKey string // default: "resolvespec_keystore_create_key"
|
||||
DeleteKey string // default: "resolvespec_keystore_delete_key"
|
||||
ValidateKey string // default: "resolvespec_keystore_validate_key"
|
||||
}
|
||||
|
||||
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
|
||||
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
|
||||
return &KeyStoreSQLNames{
|
||||
GetUserKeys: "resolvespec_keystore_get_user_keys",
|
||||
CreateKey: "resolvespec_keystore_create_key",
|
||||
DeleteKey: "resolvespec_keystore_delete_key",
|
||||
ValidateKey: "resolvespec_keystore_validate_key",
|
||||
}
|
||||
}
|
||||
|
||||
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||
// If override is nil, a copy of base is returned.
|
||||
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
|
||||
if override == nil {
|
||||
copied := *base
|
||||
return &copied
|
||||
}
|
||||
merged := *base
|
||||
if override.GetUserKeys != "" {
|
||||
merged.GetUserKeys = override.GetUserKeys
|
||||
}
|
||||
if override.CreateKey != "" {
|
||||
merged.CreateKey = override.CreateKey
|
||||
}
|
||||
if override.DeleteKey != "" {
|
||||
merged.DeleteKey = override.DeleteKey
|
||||
}
|
||||
if override.ValidateKey != "" {
|
||||
merged.ValidateKey = override.ValidateKey
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
|
||||
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
|
||||
fields := map[string]string{
|
||||
"GetUserKeys": names.GetUserKeys,
|
||||
"CreateKey": names.CreateKey,
|
||||
"DeleteKey": names.DeleteKey,
|
||||
"ValidateKey": names.ValidateKey,
|
||||
}
|
||||
for field, val := range fields {
|
||||
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
||||
`, userJSON).Scan(&success, &errMsg, &userID)
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_createsession($1::jsonb)
|
||||
`, sessionJSON).Scan(&success, &errMsg)
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getrefreshtoken($1)
|
||||
`, refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
|
||||
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update session: %w", err)
|
||||
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getuser($1)
|
||||
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
|
||||
@@ -7,16 +7,21 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -29,6 +34,11 @@ type DatabasePasskeyProviderOptions struct {
|
||||
RPOrigin string
|
||||
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -37,15 +47,39 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
opts.Timeout = 60000 // 60 seconds default
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginRegistration creates registration options for a new passkey
|
||||
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||
// Generate challenge
|
||||
@@ -132,8 +166,8 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
}
|
||||
@@ -173,8 +207,8 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var userID sql.NullInt64
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -233,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
}
|
||||
@@ -264,8 +306,8 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var updateError sql.NullString
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
}
|
||||
@@ -283,8 +325,8 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var errorMsg sql.NullString
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -362,8 +404,8 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
}
|
||||
@@ -388,8 +430,8 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
}
|
||||
|
||||
@@ -58,15 +58,17 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
|
||||
// DatabaseAuthenticator provides session-based authentication with database storage
|
||||
// All database operations go through stored procedures for security and consistency
|
||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||
// resolvespec_session_update, resolvespec_refresh_token
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See database_schema.sql for procedure definitions
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
@@ -85,6 +87,12 @@ type DatabaseAuthenticatorOptions struct {
|
||||
Cache *cache.Cache
|
||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||
PasskeyProvider PasskeyProvider
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
// Partial overrides are supported: only set the fields you want to change.
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -103,14 +111,38 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
cacheInstance = cache.GetDefaultCache()
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
sqlNames: sqlNames,
|
||||
passkeyProvider: opts.PasskeyProvider,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) getDB() *sql.DB {
|
||||
a.dbMu.RLock()
|
||||
defer a.dbMu.RUnlock()
|
||||
return a.db
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||
if a.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := a.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.dbMu.Lock()
|
||||
a.db = newDB
|
||||
a.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Convert LoginRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -118,13 +150,20 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
return nil, fmt.Errorf("failed to marshal login request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
runLoginQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -153,13 +192,12 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_register stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
@@ -187,13 +225,12 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
return fmt.Errorf("failed to marshal logout request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -266,8 +303,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
@@ -338,25 +375,23 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
return
|
||||
}
|
||||
|
||||
// Call resolvespec_session_update stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||||
// Call api_refresh_token stored procedure
|
||||
// First, we need to get the current user context for the refresh token
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
}
|
||||
@@ -368,13 +403,12 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
return nil, fmt.Errorf("invalid refresh token")
|
||||
}
|
||||
|
||||
// Call resolvespec_refresh_token to generate new token
|
||||
var newSuccess bool
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
}
|
||||
@@ -401,28 +435,65 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
|
||||
// JWTAuthenticator provides JWT token-based authentication
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
|
||||
type JWTAuthenticator struct {
|
||||
secretKey []byte
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
|
||||
return &JWTAuthenticator{
|
||||
secretKey: []byte(secretKey),
|
||||
db: db,
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
|
||||
a.dbFactory = factory
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) getDB() *sql.DB {
|
||||
a.dbMu.RLock()
|
||||
defer a.dbMu.RUnlock()
|
||||
return a.db
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) reconnectDB() error {
|
||||
if a.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := a.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.dbMu.Lock()
|
||||
a.db = newDB
|
||||
a.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Call resolvespec_jwt_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
|
||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
runLoginQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
}
|
||||
err := runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -471,12 +542,11 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// Call resolvespec_jwt_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
|
||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||
err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -511,25 +581,60 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
|
||||
// DatabaseColumnSecurityProvider loads column security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_column_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseColumnSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db}
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||
var rules []ColumnSecurity
|
||||
|
||||
// Call resolvespec_column_security stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var rulesJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
|
||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||
}
|
||||
@@ -576,23 +681,57 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
|
||||
// DatabaseRowSecurityProvider loads row security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_row_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseRowSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db}
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||
var template string
|
||||
var hasBlock bool
|
||||
|
||||
// Call resolvespec_row_security stored procedure
|
||||
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
|
||||
|
||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
|
||||
}
|
||||
@@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
|
||||
// Helper functions
|
||||
// ================
|
||||
|
||||
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
func parseRoles(rolesStr string) []string {
|
||||
if rolesStr == "" {
|
||||
return []string{}
|
||||
@@ -758,56 +902,55 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user data from database
|
||||
var username, email, roles string
|
||||
var userLevel int
|
||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
// Build request JSON for passkey login stored procedure
|
||||
reqData := map[string]any{
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Extract IP and user agent from claims
|
||||
ipAddress := ""
|
||||
userAgent := ""
|
||||
if req.Claims != nil {
|
||||
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||
ipAddress = ip
|
||||
reqData["ip_address"] = ip
|
||||
}
|
||||
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||
userAgent = ua
|
||||
reqData["user_agent"] = ua
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, now())`
|
||||
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
||||
reqJSON, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
|
||||
}
|
||||
|
||||
// Update last login
|
||||
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
||||
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
// Return login response
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
User: &UserContext{
|
||||
UserID: userID,
|
||||
UserName: username,
|
||||
Email: email,
|
||||
UserLevel: userLevel,
|
||||
SessionID: sessionToken,
|
||||
Roles: parseRoles(roles),
|
||||
},
|
||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||
}, nil
|
||||
runPasskeyQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runPasskeyQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runPasskeyQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("passkey login failed")
|
||||
}
|
||||
|
||||
var response LoginResponse
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||
|
||||
222
pkg/security/sql_names.go
Normal file
222
pkg/security/sql_names.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// SQLNames defines all configurable SQL stored procedure and table names
|
||||
// used by the security package. Override individual fields to remap
|
||||
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
|
||||
// and MergeSQLNames() to apply partial overrides.
|
||||
type SQLNames struct {
|
||||
// Auth procedures (DatabaseAuthenticator)
|
||||
Login string // default: "resolvespec_login"
|
||||
Register string // default: "resolvespec_register"
|
||||
Logout string // default: "resolvespec_logout"
|
||||
Session string // default: "resolvespec_session"
|
||||
SessionUpdate string // default: "resolvespec_session_update"
|
||||
RefreshToken string // default: "resolvespec_refresh_token"
|
||||
|
||||
// JWT procedures (JWTAuthenticator)
|
||||
JWTLogin string // default: "resolvespec_jwt_login"
|
||||
JWTLogout string // default: "resolvespec_jwt_logout"
|
||||
|
||||
// Security policy procedures
|
||||
ColumnSecurity string // default: "resolvespec_column_security"
|
||||
RowSecurity string // default: "resolvespec_row_security"
|
||||
|
||||
// TOTP procedures (DatabaseTwoFactorProvider)
|
||||
TOTPEnable string // default: "resolvespec_totp_enable"
|
||||
TOTPDisable string // default: "resolvespec_totp_disable"
|
||||
TOTPGetStatus string // default: "resolvespec_totp_get_status"
|
||||
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
|
||||
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
|
||||
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
|
||||
|
||||
// Passkey procedures (DatabasePasskeyProvider)
|
||||
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
|
||||
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
|
||||
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
|
||||
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
|
||||
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
|
||||
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
|
||||
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
|
||||
PasskeyLogin string // default: "resolvespec_passkey_login"
|
||||
|
||||
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
|
||||
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
|
||||
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
|
||||
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
|
||||
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
|
||||
OAuthGetUser string // default: "resolvespec_oauth_getuser"
|
||||
|
||||
}
|
||||
|
||||
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
|
||||
func DefaultSQLNames() *SQLNames {
|
||||
return &SQLNames{
|
||||
Login: "resolvespec_login",
|
||||
Register: "resolvespec_register",
|
||||
Logout: "resolvespec_logout",
|
||||
Session: "resolvespec_session",
|
||||
SessionUpdate: "resolvespec_session_update",
|
||||
RefreshToken: "resolvespec_refresh_token",
|
||||
|
||||
JWTLogin: "resolvespec_jwt_login",
|
||||
JWTLogout: "resolvespec_jwt_logout",
|
||||
|
||||
ColumnSecurity: "resolvespec_column_security",
|
||||
RowSecurity: "resolvespec_row_security",
|
||||
|
||||
TOTPEnable: "resolvespec_totp_enable",
|
||||
TOTPDisable: "resolvespec_totp_disable",
|
||||
TOTPGetStatus: "resolvespec_totp_get_status",
|
||||
TOTPGetSecret: "resolvespec_totp_get_secret",
|
||||
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
|
||||
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
|
||||
|
||||
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
|
||||
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
|
||||
PasskeyGetCredential: "resolvespec_passkey_get_credential",
|
||||
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
|
||||
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
|
||||
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
|
||||
PasskeyUpdateName: "resolvespec_passkey_update_name",
|
||||
PasskeyLogin: "resolvespec_passkey_login",
|
||||
|
||||
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
|
||||
OAuthCreateSession: "resolvespec_oauth_createsession",
|
||||
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
|
||||
OAuthGetUser: "resolvespec_oauth_getuser",
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||
// If override is nil, a copy of base is returned.
|
||||
func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||
if override == nil {
|
||||
copied := *base
|
||||
return &copied
|
||||
}
|
||||
merged := *base
|
||||
if override.Login != "" {
|
||||
merged.Login = override.Login
|
||||
}
|
||||
if override.Register != "" {
|
||||
merged.Register = override.Register
|
||||
}
|
||||
if override.Logout != "" {
|
||||
merged.Logout = override.Logout
|
||||
}
|
||||
if override.Session != "" {
|
||||
merged.Session = override.Session
|
||||
}
|
||||
if override.SessionUpdate != "" {
|
||||
merged.SessionUpdate = override.SessionUpdate
|
||||
}
|
||||
if override.RefreshToken != "" {
|
||||
merged.RefreshToken = override.RefreshToken
|
||||
}
|
||||
if override.JWTLogin != "" {
|
||||
merged.JWTLogin = override.JWTLogin
|
||||
}
|
||||
if override.JWTLogout != "" {
|
||||
merged.JWTLogout = override.JWTLogout
|
||||
}
|
||||
if override.ColumnSecurity != "" {
|
||||
merged.ColumnSecurity = override.ColumnSecurity
|
||||
}
|
||||
if override.RowSecurity != "" {
|
||||
merged.RowSecurity = override.RowSecurity
|
||||
}
|
||||
if override.TOTPEnable != "" {
|
||||
merged.TOTPEnable = override.TOTPEnable
|
||||
}
|
||||
if override.TOTPDisable != "" {
|
||||
merged.TOTPDisable = override.TOTPDisable
|
||||
}
|
||||
if override.TOTPGetStatus != "" {
|
||||
merged.TOTPGetStatus = override.TOTPGetStatus
|
||||
}
|
||||
if override.TOTPGetSecret != "" {
|
||||
merged.TOTPGetSecret = override.TOTPGetSecret
|
||||
}
|
||||
if override.TOTPRegenerateBackup != "" {
|
||||
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
|
||||
}
|
||||
if override.TOTPValidateBackupCode != "" {
|
||||
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
|
||||
}
|
||||
if override.PasskeyStoreCredential != "" {
|
||||
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
|
||||
}
|
||||
if override.PasskeyGetCredsByUsername != "" {
|
||||
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
|
||||
}
|
||||
if override.PasskeyGetCredential != "" {
|
||||
merged.PasskeyGetCredential = override.PasskeyGetCredential
|
||||
}
|
||||
if override.PasskeyUpdateCounter != "" {
|
||||
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
|
||||
}
|
||||
if override.PasskeyGetUserCredentials != "" {
|
||||
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
|
||||
}
|
||||
if override.PasskeyDeleteCredential != "" {
|
||||
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
|
||||
}
|
||||
if override.PasskeyUpdateName != "" {
|
||||
merged.PasskeyUpdateName = override.PasskeyUpdateName
|
||||
}
|
||||
if override.PasskeyLogin != "" {
|
||||
merged.PasskeyLogin = override.PasskeyLogin
|
||||
}
|
||||
if override.OAuthGetOrCreateUser != "" {
|
||||
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
|
||||
}
|
||||
if override.OAuthCreateSession != "" {
|
||||
merged.OAuthCreateSession = override.OAuthCreateSession
|
||||
}
|
||||
if override.OAuthGetRefreshToken != "" {
|
||||
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
|
||||
}
|
||||
if override.OAuthUpdateRefreshToken != "" {
|
||||
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
|
||||
}
|
||||
if override.OAuthGetUser != "" {
|
||||
merged.OAuthGetUser = override.OAuthGetUser
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
|
||||
// Returns an error if any field contains invalid characters.
|
||||
func ValidateSQLNames(names *SQLNames) error {
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
val := field.String()
|
||||
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSQLNames merges an optional override with defaults.
|
||||
// Used by constructors that accept variadic *SQLNames.
|
||||
func resolveSQLNames(override ...*SQLNames) *SQLNames {
|
||||
if len(override) > 0 && override[0] != nil {
|
||||
return MergeSQLNames(DefaultSQLNames(), override[0])
|
||||
}
|
||||
return DefaultSQLNames()
|
||||
}
|
||||
145
pkg/security/sql_names_test.go
Normal file
145
pkg/security/sql_names_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if field.String() == "" {
|
||||
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_PartialOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{
|
||||
Login: "custom_login",
|
||||
TOTPEnable: "custom_totp_enable",
|
||||
PasskeyLogin: "custom_passkey_login",
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
|
||||
if merged.Login != "custom_login" {
|
||||
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
|
||||
}
|
||||
if merged.TOTPEnable != "custom_totp_enable" {
|
||||
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
|
||||
}
|
||||
if merged.PasskeyLogin != "custom_passkey_login" {
|
||||
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
|
||||
}
|
||||
// Non-overridden fields should retain defaults
|
||||
if merged.Logout != "resolvespec_logout" {
|
||||
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
|
||||
}
|
||||
if merged.Session != "resolvespec_session" {
|
||||
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_NilOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
merged := MergeSQLNames(base, nil)
|
||||
|
||||
// Should be a copy, not the same pointer
|
||||
if merged == base {
|
||||
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
|
||||
}
|
||||
|
||||
// All values should match
|
||||
v1 := reflect.ValueOf(base).Elem()
|
||||
v2 := reflect.ValueOf(merged).Elem()
|
||||
typ := v1.Type()
|
||||
|
||||
for i := 0; i < v1.NumField(); i++ {
|
||||
f1 := v1.Field(i)
|
||||
f2 := v2.Field(i)
|
||||
if f1.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if f1.String() != f2.String() {
|
||||
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
originalLogin := base.Login
|
||||
|
||||
override := &SQLNames{Login: "custom_login"}
|
||||
_ = MergeSQLNames(base, override)
|
||||
|
||||
if base.Login != originalLogin {
|
||||
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{}
|
||||
v := reflect.ValueOf(override).Elem()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if v.Field(i).Kind() == reflect.String {
|
||||
v.Field(i).SetString("custom_sentinel")
|
||||
}
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
mv := reflect.ValueOf(merged).Elem()
|
||||
typ := mv.Type()
|
||||
for i := 0; i < mv.NumField(); i++ {
|
||||
if mv.Field(i).Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if mv.Field(i).String() != "custom_sentinel" {
|
||||
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Valid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
if err := ValidateSQLNames(names); err != nil {
|
||||
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Invalid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
names.Login = "resolvespec_login; DROP TABLE users; --"
|
||||
|
||||
err := ValidateSQLNames(names)
|
||||
if err == nil {
|
||||
t.Error("ValidateSQLNames should reject names with invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_NoOverride(t *testing.T) {
|
||||
names := resolveSQLNames()
|
||||
if names.Login != "resolvespec_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_WithOverride(t *testing.T) {
|
||||
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
|
||||
if names.Login != "custom_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
|
||||
}
|
||||
if names.Logout != "resolvespec_logout" {
|
||||
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
|
||||
}
|
||||
}
|
||||
@@ -9,23 +9,23 @@ import (
|
||||
)
|
||||
|
||||
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
||||
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
|
||||
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
|
||||
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See totp_database_schema.sql for procedure definitions
|
||||
type DatabaseTwoFactorProvider struct {
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &DatabaseTwoFactorProvider{
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
|
||||
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||
var errorMsg sql.NullString
|
||||
var enabled bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var errorMsg sql.NullString
|
||||
var secret sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
|
||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
|
||||
var errorMsg sql.NullString
|
||||
var valid bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
|
||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user