Compare commits

..

33 Commits

Author SHA1 Message Date
Hein
354ed2a8dc feat(db): add fallback metric entity handling for unknown targets
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m5s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m35s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m19s
Build , Vet Test, and Lint / Build (push) Successful in -29m43s
Tests / Integration Tests (push) Failing after -30m34s
Tests / Unit Tests (push) Successful in -30m8s
* implement fallbackMetricEntityFromQuery for query sanitization
* add tests for fallback metric entity and sanitization logic
2026-04-10 16:00:22 +02:00
Hein
dfb63c3328 refactor(db): remove metrics enabling methods from adapters 2026-04-10 14:13:15 +02:00
Hein
e8d0ab28c3 feat(db): add query metrics tracking for database operations
* Introduced metrics tracking for SELECT, INSERT, UPDATE, and DELETE operations.
* Added methods to enable or disable metrics on the PgSQLAdapter.
* Created a new query_metrics.go file to handle metrics recording logic.
* Updated interfaces and implementations to support schema and entity tracking.
* Added tests to verify metrics recording functionality.
2026-04-10 13:51:46 +02:00
Hein
4fc25c60ae fix(db): correct connection pool assignment in GORM adapter
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -29m43s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m6s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m11s
Build , Vet Test, and Lint / Build (push) Successful in -29m33s
Tests / Unit Tests (push) Successful in -30m4s
Tests / Integration Tests (push) Failing after -30m13s
2026-04-10 11:20:44 +02:00
Hein
16a960d973 feat(db): add reconnect logic for database adapters
* Implement reconnect functionality in GormAdapter and other database adapters.
* Introduce a DBFactory to handle reconnections.
* Update health check logic to skip reconnects for transient failures.
* Add tests for reconnect behavior in DatabaseAuthenticator.
2026-04-10 11:18:39 +02:00
Hein
2afee9d238 fix(db): handle database reconnection in transactions 2026-04-10 08:42:41 +02:00
Hein Puth (Warkanum)
1e89124c97 Merge pull request #18 from bitechdev/feature-auth-mcp
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m1s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m29s
Build , Vet Test, and Lint / Build (push) Successful in -29m37s
Build , Vet Test, and Lint / Lint Code (push) Successful in -28m58s
Tests / Unit Tests (push) Successful in -30m23s
Tests / Integration Tests (push) Failing after -30m31s
feat(security): implement OAuth2 authorization server with database s…
2026-04-09 16:18:18 +02:00
copilot-swe-agent[bot]
ca0545e144 fix(security): address validation review comments - mutex safety and issuer normalization
Agent-Logs-Url: https://github.com/bitechdev/ResolveSpec/sessions/e886b781-c910-425f-aa6f-06d13c46dcc7

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2026-04-09 14:07:45 +00:00
copilot-swe-agent[bot]
850ad2b2ab fix(security): address all OAuth2 PR review issues
Agent-Logs-Url: https://github.com/bitechdev/ResolveSpec/sessions/e886b781-c910-425f-aa6f-06d13c46dcc7

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2026-04-09 14:04:53 +00:00
Hein
2a2e33da0c Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into feature-auth-mcp 2026-04-09 15:52:26 +02:00
Hein Puth (Warkanum)
17808a8121 Merge pull request #19 from bitechdev/feature-keystore
Feature keystore
2026-04-09 15:50:36 +02:00
Hein
134ff85c59 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into feature-keystore 2026-04-09 15:47:54 +02:00
Hein
bacddc58a6 style(recursive_crud): remove unnecessary blank line 2026-04-09 15:37:13 +02:00
Hein
f1ad83d966 feat(reflection): add JSON to DB column name mapping functions
* Implement BuildJSONToDBColumnMap for translating JSON keys to DB column names
* Enhance GetColumnName to extract column names with priority
* Update filterValidFields to utilize new mapping for improved data handling
* Fix TestToSnakeCase expected values for consistency
2026-04-09 15:36:52 +02:00
Hein
79a3912f93 fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
2026-04-09 09:19:28 +02:00
6502b55797 feat(security): implement OAuth2 authorization server with database support
- Add OAuthServer for handling OAuth2 flows including authorization, token exchange, and client registration.
- Introduce DatabaseAuthenticator for persisting clients and authorization codes.
- Implement SQL procedures for client registration, code saving, and token introspection.
- Support for external OAuth2 providers and PKCE (Proof Key for Code Exchange).
2026-04-07 22:56:05 +02:00
aa095d6bfd fix(tests): replace panic with log.Fatal for better error handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m17s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m52s
Build , Vet Test, and Lint / Build (push) Successful in -29m52s
Build , Vet Test, and Lint / Lint Code (push) Failing after -29m23s
Tests / Integration Tests (push) Failing after -30m46s
Tests / Unit Tests (push) Successful in -28m51s
2026-04-07 20:38:22 +02:00
ea5bb38ee4 feat(handler): update to use static base path for SSE server 2026-04-07 20:03:43 +02:00
c2e2c9b873 feat(transport): add streamable HTTP transport for MCP 2026-04-07 19:52:38 +02:00
4adf94fe37 feat(go.mod): add mcp-go dependency for enhanced functionality 2026-04-07 19:09:51 +02:00
Hein
a9bf08f58b feat(security): implement keystore for user authentication keys
* Add ConfigKeyStore for in-memory key management
* Introduce DatabaseKeyStore for PostgreSQL-backed key storage
* Create KeyStoreAuthenticator for API key validation
* Define SQL procedures for key management in PostgreSQL
* Document keystore functionality and usage in KEYSTORE.md
2026-04-07 17:09:17 +02:00
Hein
405a04a192 feat(security): integrate security hooks for access control
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m3s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m36s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m6s
Build , Vet Test, and Lint / Build (push) Successful in -29m58s
Tests / Unit Tests (push) Successful in -30m22s
Tests / Integration Tests (push) Failing after -30m41s
* Add security hooks for per-entity operation rules and row/column-level security.
* Implement annotation tool for storing and retrieving freeform annotations.
* Enhance handler to support model registration with access rules.
2026-04-07 15:53:12 +02:00
Hein
c1b16d363a feat(db): add DB method to sqlConnection and mongoConnection
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m22s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m59s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m11s
Build , Vet Test, and Lint / Build (push) Successful in -30m12s
Tests / Unit Tests (push) Successful in -30m49s
Tests / Integration Tests (push) Failing after -30m59s
2026-04-01 15:34:09 +02:00
Hein
568df8c6d6 feat(security): add configurable SQL procedure names
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m29s
Build , Vet Test, and Lint / Build (push) Successful in -30m5s
Build , Vet Test, and Lint / Lint Code (push) Failing after -28m58s
Tests / Integration Tests (push) Failing after -30m26s
Tests / Unit Tests (push) Successful in -28m7s
* Introduce SQLNames struct to define stored procedure names.
* Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls.
* Remove hardcoded procedure names for better flexibility and customization.
* Implement validation for SQL names to ensure they are valid identifiers.
* Add tests for SQLNames functionality and merging behavior.
2026-03-31 14:25:59 +02:00
Hein
aa362c77da fix(cursor): trim parentheses from sort column names 2026-03-27 15:07:10 +02:00
Hein
1641eaf278 feat(resolvemcp): enhance handler with configuration support
* Introduce Config struct for BaseURL and BasePath settings
* Update handler creation functions to accept configuration
* Modify SSEServer to use dynamic base URL detection
* Adjust route setup functions to utilize BasePath from config
2026-03-27 13:56:03 +02:00
Hein
200a03c225 feat(resolvemcp): add SSE server and bunrouter setup functions
* Introduce SSEServer method for creating an SSE server bound to the handler.
* Add SetupBunRouterRoutes function to mount MCP HTTP/SSE endpoints on bunrouter.
* Update README with usage examples for new features.
2026-03-27 13:28:03 +02:00
Hein
7ef9cf39d3 style(tools): simplify string formatting in descriptions 2026-03-27 13:10:50 +02:00
Hein
7f6410f665 feat(resolvemcp): add support for join-column sorting in cursor pagination
* Enhance getCursorFilter to accept join clauses for sorting
* Update resolveColumn to handle joined columns
* Modify tests to validate new join functionality
2026-03-27 13:10:42 +02:00
Hein
835bbb0727 style(hooks): reorder fields in HookContext for consistency 2026-03-27 12:57:30 +02:00
Hein
047a1cc187 feat(resolvemcp): add hook system for model operations
* Implement hooks for CRUD operations: before/after handle, read, create, update, delete.
* Introduce HookContext and HookRegistry for managing hooks.
* Allow registration and execution of multiple hooks per operation.

feat(resolvemcp): implement MCP tools for CRUD operations
* Register tools for reading, creating, updating, and deleting records.
* Define tool arguments and handle requests with appropriate responses.
* Support for resource registration with metadata.

fix(restheadspec): enhance cursor handling for joins
* Improve cursor filter generation to support lateral joins.
* Update join alias extraction to handle lateral joins correctly.
* Ensure cursor filters do not contain empty comparisons.

test(restheadspec): add tests for cursor filters and join alias extraction
* Create tests for lateral join scenarios in cursor filter generation.
* Validate join alias extraction for various join types, including lateral joins.
2026-03-27 12:57:08 +02:00
Hein
7a498edab7 fix(headers): enhance relation name resolution logic
* Allow resolution for both regular headers and X-Files.
* Introduce join-key-aware resolution for disambiguation.
* Add new function to handle multiple fields pointing to the same type.
2026-03-25 12:09:03 +02:00
Hein
f10bb0827e fix(sql_helpers): ensure case-insensitive matching for allowed prefixes 2026-03-25 10:57:42 +02:00
62 changed files with 8860 additions and 605 deletions

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ test.db
tests/data/
node_modules/
resolvespec-js/dist/
.codex

View File

@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
* [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
* [Architecture](#architecture)
* [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
* **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
### ResolveMCP (v3.2+)
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
### RestHeadSpec (v2.1+)
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
### ResolveMCP (MCP Server)
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
// Create handler
handler := resolvemcp.NewHandlerWithGORM(db)
// Register models — must be done BEFORE Build()
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "posts", &Post{})
// Finalize: registers MCP tools and resources
handler.Build()
// Mount SSE transport on your existing router
router := mux.NewRouter()
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
// MCP clients connect to:
// SSE stream: GET http://localhost:8080/mcp/sse
// Messages: POST http://localhost:8080/mcp/message
//
// Auto-registered tools per model:
// read_public_users — filter, sort, paginate, preload
// create_public_users — insert a new record
// update_public_users — update a record by ID
// delete_public_users — delete a record by ID
```
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
## Architecture
### Two Complementary APIs
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
#### ResolveMCP - MCP Server
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
**Key Features**:
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
- Same Before/After lifecycle hooks as ResolveSpec
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
#### FuncSpec - Function-Based SQL API
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New
### v3.1 (Latest - February 2026)
### v3.2 (Latest - March 2026)
**ResolveMCP - Model Context Protocol Server (🆕)**:
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
### v3.1 (February 2026)
**SQLite Schema Translation (🆕)**:

5
go.mod
View File

@@ -15,6 +15,7 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.8.0
github.com/klauspost/compress v1.18.2
github.com/mark3labs/mcp-go v0.46.0
github.com/mattn/go-sqlite3 v1.14.33
github.com/microsoft/go-mssqldb v1.9.5
github.com/mochi-mqtt/server/v2 v2.7.9
@@ -40,6 +41,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.46.0
golang.org/x/oauth2 v0.34.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
@@ -78,6 +80,7 @@ require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@@ -131,6 +134,7 @@ require (
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
@@ -143,7 +147,6 @@ require (
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect

6
go.sum
View File

@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/uptrace/bun"
@@ -94,22 +95,57 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
// BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs
type BunAdapter struct {
db *bun.DB
driverName string
db *bun.DB
dbMu sync.RWMutex
dbFactory func() (*bun.DB, error)
driverName string
metricsEnabled bool
}
// NewBunAdapter creates a new Bun adapter
func NewBunAdapter(db *bun.DB) *BunAdapter {
adapter := &BunAdapter{db: db}
adapter := &BunAdapter{db: db, metricsEnabled: true}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
b.dbFactory = factory
return b
}
// SetMetricsEnabled enables or disables query metrics for this adapter.
func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter {
b.metricsEnabled = enabled
return b
}
func (b *BunAdapter) getDB() *bun.DB {
b.dbMu.RLock()
defer b.dbMu.RUnlock()
return b.db
}
func (b *BunAdapter) reconnectDB() error {
if b.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := b.dbFactory()
if err != nil {
return err
}
b.dbMu.Lock()
b.db = newDB
b.dbMu.Unlock()
return nil
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{})
b.getDB().AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
}
@@ -130,22 +166,23 @@ func (b *BunAdapter) DisableQueryDebug() {
func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{
query: b.db.NewSelect(),
db: b.db,
driverName: b.driverName,
query: b.getDB().NewSelect(),
db: b.db,
driverName: b.driverName,
metricsEnabled: b.metricsEnabled,
}
}
func (b *BunAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.db.NewInsert()}
return &BunInsertQuery{query: b.getDB().NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.db.NewUpdate()}
return &BunUpdateQuery{query: b.getDB().NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()}
return &BunDeleteQuery{query: b.getDB().NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -154,7 +191,17 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
result, err := b.db.ExecContext(ctx, query, args...)
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
var result sql.Result
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = run()
}
}
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return &BunResult{result: result}, err
}
@@ -164,16 +211,29 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
return b.db.NewRaw(query, args...).Scan(ctx, dest)
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
}
}
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
tx, err = b.getDB().BeginTx(ctx, &sql.TxOptions{})
}
}
if err != nil {
return nil, err
}
// For Bun, we'll return a special wrapper that holds the transaction
return &BunTxAdapter{tx: tx, driverName: b.driverName}, nil
return &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}, nil
}
func (b *BunAdapter) CommitTx(ctx context.Context) error {
@@ -194,15 +254,23 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
}
}()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
return fn(adapter)
})
run := func() error {
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
return fn(adapter)
})
}
err = run()
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = run()
}
}
return err
}
func (b *BunAdapter) GetUnderlyingDB() interface{} {
return b.db
return b.getDB()
}
func (b *BunAdapter) DriverName() string {
@@ -226,25 +294,24 @@ type BunSelectQuery struct {
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
entity string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
skipAutoDetect bool // Skip auto-detection to prevent circular calls
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
metricsEnabled bool
}
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
b.query = b.query.Model(model)
b.hasModel = true // Mark that we have a model
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
b.schema, b.tableName = parseTableName(fullTableName, b.driverName)
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
if provider, ok := model.(common.TableAliasProvider); ok {
b.tableAlias = provider.TableAlias()
@@ -258,6 +325,9 @@ func (b *BunSelectQuery) Table(table string) common.SelectQuery {
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
return b
}
@@ -563,9 +633,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Wrap the incoming *bun.SelectQuery in our adapter
wrapper := &BunSelectQuery{
query: sq,
db: b.db,
driverName: b.driverName,
query: sq,
db: b.db,
driverName: b.driverName,
metricsEnabled: b.metricsEnabled,
}
// Try to extract table name and alias from the preload model
@@ -816,7 +887,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
// Apply user's functions (if any)
if isLast && len(applyFuncs) > 0 {
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName}
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
for _, fn := range applyFuncs {
if fn != nil {
wrapper = fn(wrapper).(*BunSelectQuery)
@@ -1168,27 +1239,28 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
}
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
}()
if dest == nil {
return fmt.Errorf("destination cannot be nil")
err = fmt.Errorf("destination cannot be nil")
return err
}
err = b.query.Scan(ctx, dest)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
return nil
return err
}
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
// Enhanced panic recovery with model information
@@ -1198,7 +1270,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
modelValue := model.Value()
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
// Try to get the model's underlying struct type
v := reflect.ValueOf(modelValue)
if v.Kind() == reflect.Ptr {
v = v.Elem()
@@ -1218,9 +1289,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
}()
if b.query.GetModel() == nil {
return fmt.Errorf("model is nil")
err = fmt.Errorf("model is nil")
return err
}
// Optional: Enable detailed field-level debugging (set to true to debug)
@@ -1236,7 +1309,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
err = b.query.Scan(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
return err
@@ -1245,7 +1317,7 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
// After main query, load custom preloads using separate queries
if len(b.customPreloads) > 0 {
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
if err := b.loadCustomPreloads(ctx); err != nil {
if err = b.loadCustomPreloads(ctx); err != nil {
logger.Error("Failed to load custom preloads: %v", err)
return err
}
@@ -1255,21 +1327,22 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
}
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0
}
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
}()
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
count, err = b.query.Count(ctx) // assign to named returns, not shadow vars
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err
return
}
// Otherwise, wrap as subquery to avoid "Model(nil)" error
@@ -1279,39 +1352,49 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
ColumnExpr("COUNT(*)")
err = countQuery.Scan(ctx, &count)
if err != nil {
// Log SQL string for debugging
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err
return
}
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false
}
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
}()
exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return exists, err
return
}
// BunInsertQuery implements InsertQuery for Bun
type BunInsertQuery struct {
query *bun.InsertQuery
values map[string]interface{}
hasModel bool
query *bun.InsertQuery
values map[string]interface{}
hasModel bool
driverName string
schema string
tableName string
entity string
metricsEnabled bool
}
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
b.query = b.query.Model(model)
b.hasModel = true
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
return b
}
@@ -1320,6 +1403,10 @@ func (b *BunInsertQuery) Table(table string) common.InsertQuery {
return b
}
b.query = b.query.Table(table)
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
return b
}
@@ -1349,6 +1436,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
startedAt := time.Now()
if len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
@@ -1362,29 +1450,45 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
}
}
result, err := b.query.Exec(ctx)
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
return &BunResult{result: result}, err
}
// BunUpdateQuery implements UpdateQuery for Bun
type BunUpdateQuery struct {
query *bun.UpdateQuery
model interface{}
query *bun.UpdateQuery
model interface{}
driverName string
schema string
tableName string
entity string
metricsEnabled bool
}
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
b.query = b.query.Model(model)
b.model = model
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
return b
}
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
b.query = b.query.Table(table)
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
if b.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
b.model = model
b.entity = entityNameFromModel(model, b.tableName)
}
}
return b
@@ -1435,27 +1539,43 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
}
}()
startedAt := time.Now()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "UPDATE", b.schema, b.entity, b.tableName, startedAt, err)
return &BunResult{result: result}, err
}
// BunDeleteQuery implements DeleteQuery for Bun
type BunDeleteQuery struct {
query *bun.DeleteQuery
query *bun.DeleteQuery
driverName string
schema string
tableName string
entity string
metricsEnabled bool
}
func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery {
b.query = b.query.Model(model)
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
return b
}
func (b *BunDeleteQuery) Table(table string) common.DeleteQuery {
b.query = b.query.Table(table)
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
return b
}
@@ -1470,12 +1590,14 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
}
}()
startedAt := time.Now()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "DELETE", b.schema, b.entity, b.tableName, startedAt, err)
return &BunResult{result: result}, err
}
@@ -1501,37 +1623,46 @@ func (b *BunResult) LastInsertId() (int64, error) {
// BunTxAdapter wraps a Bun transaction to implement the Database interface
type BunTxAdapter struct {
tx bun.Tx
driverName string
tx bun.Tx
driverName string
metricsEnabled bool
}
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{
query: b.tx.NewSelect(),
db: b.tx,
driverName: b.driverName,
query: b.tx.NewSelect(),
db: b.tx,
driverName: b.driverName,
metricsEnabled: b.metricsEnabled,
}
}
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.tx.NewInsert()}
return &BunInsertQuery{query: b.tx.NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunTxAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.tx.NewUpdate()}
return &BunUpdateQuery{query: b.tx.NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunTxAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.tx.NewDelete()}
return &BunDeleteQuery{query: b.tx.NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
result, err := b.tx.ExecContext(ctx, query, args...)
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return &BunResult{result: result}, err
}
func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return b.tx.NewRaw(query, args...).Scan(ctx, dest)
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
err := b.tx.NewRaw(query, args...).Scan(ctx, dest)
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {

View File

@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"gorm.io/gorm"
@@ -15,22 +17,93 @@ import (
// GormAdapter adapts GORM to work with our Database interface
type GormAdapter struct {
db *gorm.DB
driverName string
dbMu sync.RWMutex
db *gorm.DB
dbFactory func() (*gorm.DB, error)
driverName string
metricsEnabled bool
}
// NewGormAdapter creates a new GORM adapter
func NewGormAdapter(db *gorm.DB) *GormAdapter {
adapter := &GormAdapter{db: db}
adapter := &GormAdapter{db: db, metricsEnabled: true}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
g.dbFactory = factory
return g
}
// SetMetricsEnabled enables or disables query metrics for this adapter.
func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter {
g.metricsEnabled = enabled
return g
}
func (g *GormAdapter) getDB() *gorm.DB {
g.dbMu.RLock()
defer g.dbMu.RUnlock()
return g.db
}
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
if g.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
freshDB, err := g.dbFactory()
if err != nil {
return err
}
g.dbMu.Lock()
previous := g.db
g.db = freshDB
g.driverName = normalizeGormDriverName(freshDB)
g.dbMu.Unlock()
if previous != nil {
syncGormConnPool(previous, freshDB)
}
for _, target := range targets {
if target != nil && target != previous {
syncGormConnPool(target, freshDB)
}
}
return nil
}
func syncGormConnPool(target, fresh *gorm.DB) {
if target == nil || fresh == nil {
return
}
if target.Config != nil && fresh.Config != nil {
target.ConnPool = fresh.ConnPool
}
if target.Statement != nil {
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
target.Statement.ConnPool = fresh.Statement.ConnPool
} else if fresh.Config != nil {
target.Statement.ConnPool = fresh.ConnPool
}
target.Statement.DB = target
}
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.dbMu.Lock()
g.db = g.db.Debug()
g.dbMu.Unlock()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
@@ -44,19 +117,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
}
func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db, driverName: g.driverName}
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewInsert() common.InsertQuery {
return &GormInsertQuery{db: g.db}
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
return &GormUpdateQuery{db: g.db}
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db}
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -65,7 +138,18 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Exec(query, args...)
}
result := run()
if isDBClosed(result.Error) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
result = run()
}
}
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}
@@ -75,15 +159,35 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
run := func() error {
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
}
err = run()
if isDBClosed(err) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
err = run()
}
}
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx := g.db.WithContext(ctx).Begin()
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Begin()
}
tx := run()
if isDBClosed(tx.Error) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
tx = run()
}
}
if tx.Error != nil {
return nil, tx.Error
}
return &GormAdapter{db: tx, driverName: g.driverName}, nil
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}, nil
}
func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -100,24 +204,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, driverName: g.driverName}
return fn(adapter)
})
run := func() error {
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
return fn(adapter)
})
}
err = run()
if isDBClosed(err) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
err = run()
}
}
return err
}
func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db
return g.getDB()
}
func (g *GormAdapter) DriverName() string {
if g.db.Dialector == nil {
return normalizeGormDriverName(g.getDB())
}
func normalizeGormDriverName(db *gorm.DB) string {
if db == nil || db.Dialector == nil {
return ""
}
// Normalize GORM's dialector name to match the project's canonical vocabulary.
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
switch name := g.db.Name(); name {
switch name := db.Name(); name {
case "sqlserver":
return "mssql"
case "sqlite3":
@@ -130,24 +247,21 @@ func (g *GormAdapter) DriverName() string {
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
schema string // Separated schema name
tableName string // Just the table name, without schema
entity string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
metricsEnabled bool
}
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
g.db = g.db.Model(model)
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
}
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
if provider, ok := model.(common.TableAliasProvider); ok {
g.tableAlias = provider.TableAlias()
@@ -161,6 +275,9 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
return g
}
@@ -346,8 +463,10 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
}
wrapper := &GormSelectQuery{
db: db,
driverName: g.driverName,
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
metricsEnabled: g.metricsEnabled,
}
current := common.SelectQuery(wrapper)
@@ -385,9 +504,11 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
wrapper := &GormSelectQuery{
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
metricsEnabled: g.metricsEnabled,
}
current := common.SelectQuery(wrapper)
@@ -444,7 +565,16 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
err = g.db.WithContext(ctx).Find(dest).Error
startedAt := time.Now()
run := func() error {
return g.db.WithContext(ctx).Find(dest).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -452,6 +582,7 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
return err
}
@@ -464,7 +595,16 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
startedAt := time.Now()
run := func() error {
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -472,6 +612,7 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
return err
}
@@ -482,8 +623,17 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
count = 0
}
}()
startedAt := time.Now()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
run := func() error {
return g.db.WithContext(ctx).Count(&count64).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -491,6 +641,7 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "COUNT", g.schema, g.entity, g.tableName, startedAt, err)
return int(count64), err
}
@@ -501,8 +652,17 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
exists = false
}
}()
startedAt := time.Now()
var count int64
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
run := func() error {
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -510,24 +670,37 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "EXISTS", g.schema, g.entity, g.tableName, startedAt, err)
return count > 0, err
}
// GormInsertQuery implements InsertQuery for GORM
type GormInsertQuery struct {
db *gorm.DB
model interface{}
values map[string]interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
values map[string]interface{}
schema string
tableName string
entity string
driverName string
metricsEnabled bool
}
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
return g
}
@@ -555,38 +728,60 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
switch {
case g.model != nil:
result = g.db.WithContext(ctx).Create(g.model)
case g.values != nil:
result = g.db.WithContext(ctx).Create(g.values)
default:
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
startedAt := time.Now()
run := func() *gorm.DB {
switch {
case g.model != nil:
return g.db.WithContext(ctx).Create(g.model)
case g.values != nil:
return g.db.WithContext(ctx).Create(g.values)
default:
return g.db.WithContext(ctx).Create(map[string]interface{}{})
}
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}
// GormUpdateQuery implements UpdateQuery for GORM
type GormUpdateQuery struct {
db *gorm.DB
model interface{}
updates interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
updates interface{}
schema string
tableName string
entity string
driverName string
metricsEnabled bool
}
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
if g.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
g.model = model
g.entity = entityNameFromModel(model, g.tableName)
}
}
return g
@@ -647,7 +842,16 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
startedAt := time.Now()
run := func() *gorm.DB {
return g.db.WithContext(ctx).Updates(g.updates)
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -655,23 +859,36 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
recordQueryMetrics(g.metricsEnabled, "UPDATE", g.schema, g.entity, g.tableName, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}
// GormDeleteQuery implements DeleteQuery for GORM
type GormDeleteQuery struct {
db *gorm.DB
model interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
schema string
tableName string
entity string
driverName string
metricsEnabled bool
}
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
return g
}
@@ -686,7 +903,16 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
startedAt := time.Now()
run := func() *gorm.DB {
return g.db.WithContext(ctx).Delete(g.model)
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -694,6 +920,7 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
recordQueryMetrics(g.metricsEnabled, "DELETE", g.schema, g.entity, g.tableName, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}

View File

@@ -5,7 +5,10 @@ import (
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -16,8 +19,11 @@ import (
// PgSQLAdapter adapts standard database/sql to work with our Database interface
// This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct {
db *sql.DB
driverName string
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string
metricsEnabled bool
}
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
@@ -28,7 +34,43 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
if len(driverName) > 0 && driverName[0] != "" {
name = driverName[0]
}
return &PgSQLAdapter{db: db, driverName: name}
return &PgSQLAdapter{db: db, driverName: name, metricsEnabled: true}
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
p.dbFactory = factory
return p
}
// SetMetricsEnabled enables or disables query metrics for this adapter.
func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter {
p.metricsEnabled = enabled
return p
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *PgSQLAdapter) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// EnableQueryDebug enables query debugging for development
@@ -38,37 +80,41 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
db: p.db,
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
db: p.getDB(),
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
db: p.db,
driverName: p.driverName,
values: make(map[string]interface{}),
db: p.getDB(),
driverName: p.driverName,
values: make(map[string]interface{}),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
db: p.db,
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
db: p.getDB(),
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
db: p.db,
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
db: p.getDB(),
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
@@ -78,12 +124,23 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
result, err := p.db.ExecContext(ctx, query, args...)
var result sql.Result
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
logger.Error("PgSQL Exec failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
@@ -93,23 +150,35 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
err = logger.HandlePanic("PgSQLAdapter.Query", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
rows, err := p.db.QueryContext(ctx, query, args...)
var rows *sql.Rows
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
logger.Error("PgSQL Query failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
defer rows.Close()
return scanRows(rows, dest)
err = scanRows(rows, dest)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}, nil
}
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
@@ -127,12 +196,12 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
}
}()
tx, err := p.db.BeginTx(ctx, nil)
tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil {
return err
}
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}
defer func() {
if p := recover(); p != nil {
@@ -175,34 +244,34 @@ type relationMetadata struct {
// PgSQLSelectQuery implements SelectQuery for PostgreSQL
type PgSQLSelectQuery struct {
db *sql.DB
tx *sql.Tx
model interface{}
tableName string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
columns []string
columnExprs []string
whereClauses []string
orClauses []string
joins []string
orderBy []string
groupBy []string
havingClauses []string
limit int
offset int
args []interface{}
paramCounter int
preloads []preloadConfig
db *sql.DB
tx *sql.Tx
model interface{}
entity string
tableName string
schema string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
columns []string
columnExprs []string
whereClauses []string
orClauses []string
joins []string
orderBy []string
groupBy []string
havingClauses []string
limit int
offset int
args []interface{}
paramCounter int
preloads []preloadConfig
metricsEnabled bool
}
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
p.model = model
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
if provider, ok := model.(common.TableAliasProvider); ok {
p.tableAlias = provider.TableAlias()
}
@@ -211,7 +280,10 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
return p
}
@@ -421,6 +493,7 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
}
}()
startedAt := time.Now()
// Apply preloads that use JOINs
p.applyJoinPreloads()
@@ -437,17 +510,21 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
if err != nil {
logger.Error("PgSQL SELECT failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
return err
}
defer rows.Close()
err = scanRows(rows, dest)
if err != nil {
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
return err
}
// Apply preloads that use separate queries
return p.applySubqueryPreloads(ctx, dest)
err = p.applySubqueryPreloads(ctx, dest)
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
return err
}
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
@@ -457,15 +534,8 @@ func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
return p.Scan(ctx, p.model)
}
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
count = 0
}
}()
// Build a COUNT query
// countInternal executes the COUNT query and returns the result without recording metrics.
func (p *PgSQLSelectQuery) countInternal(ctx context.Context) (int, error) {
var sb strings.Builder
sb.WriteString("SELECT COUNT(*) FROM ")
sb.WriteString(p.tableName)
@@ -499,10 +569,26 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
row = p.db.QueryRowContext(ctx, query, p.args...)
}
err = row.Scan(&count)
var count int
if err := row.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
count = 0
}
}()
startedAt := time.Now()
count, err = p.countInternal(ctx)
if err != nil {
logger.Error("PgSQL COUNT failed: %v", err)
}
recordQueryMetrics(p.metricsEnabled, "COUNT", p.schema, p.entity, p.tableName, startedAt, err)
return count, err
}
@@ -513,27 +599,32 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
exists = false
}
}()
count, err := p.Count(ctx)
startedAt := time.Now()
count, err := p.countInternal(ctx)
if err != nil {
logger.Error("PgSQL EXISTS failed: %v", err)
}
recordQueryMetrics(p.metricsEnabled, "EXISTS", p.schema, p.entity, p.tableName, startedAt, err)
return count > 0, err
}
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
type PgSQLInsertQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
values map[string]interface{}
returning []string
db *sql.DB
tx *sql.Tx
schema string
tableName string
entity string
driverName string
values map[string]interface{}
valueOrder []string
returning []string
metricsEnabled bool
}
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
// Extract values from model using reflection
// This is a simplified implementation
return p
@@ -541,11 +632,17 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
return p
}
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
if _, exists := p.values[column]; !exists {
p.valueOrder = append(p.valueOrder, column)
}
p.values[column] = value
return p
}
@@ -561,25 +658,27 @@ func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
}
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
}
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
}()
if len(p.values) == 0 {
return nil, fmt.Errorf("no values to insert")
err = fmt.Errorf("no values to insert")
return nil, err
}
columns := make([]string, 0, len(p.values))
placeholders := make([]string, 0, len(p.values))
args := make([]interface{}, 0, len(p.values))
i := 1
for col, val := range p.values {
for _, col := range p.valueOrder {
columns = append(columns, col)
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
args = append(args, val)
args = append(args, p.values[col])
i++
}
@@ -611,35 +710,40 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
type PgSQLUpdateQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
model interface{}
sets map[string]interface{}
whereClauses []string
args []interface{}
paramCounter int
returning []string
db *sql.DB
tx *sql.Tx
schema string
tableName string
entity string
driverName string
model interface{}
sets map[string]interface{}
setOrder []string
whereClauses []string
args []interface{}
paramCounter int
returning []string
metricsEnabled bool
}
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
p.model = model
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
return p
}
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
if p.model == nil {
model, err := modelregistry.GetModelByName(table)
if err == nil {
p.model = model
p.entity = entityNameFromModel(model, p.tableName)
}
}
return p
@@ -649,6 +753,9 @@ func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQu
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
return p
}
if _, exists := p.sets[column]; !exists {
p.setOrder = append(p.setOrder, column)
}
p.sets[column] = value
return p
}
@@ -659,13 +766,23 @@ func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQu
pkName = reflection.GetPrimaryKeyName(p.model)
}
for column, value := range values {
orderedColumns := make([]string, 0, len(values))
for column := range values {
orderedColumns = append(orderedColumns, column)
}
sort.Strings(orderedColumns)
for _, column := range orderedColumns {
value := values[column]
if pkName != "" && column == pkName {
continue
}
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
continue
}
if _, exists := p.sets[column]; !exists {
p.setOrder = append(p.setOrder, column)
}
p.sets[column] = value
}
return p
@@ -694,24 +811,26 @@ func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) strin
}
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
}
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
}()
if len(p.sets) == 0 {
return nil, fmt.Errorf("no values to update")
err = fmt.Errorf("no values to update")
return nil, err
}
setClauses := make([]string, 0, len(p.sets))
setArgs := make([]interface{}, 0, len(p.sets))
// SET parameters start at $1
i := 1
for col, val := range p.sets {
for _, col := range p.setOrder {
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
setArgs = append(setArgs, val)
setArgs = append(setArgs, p.sets[col])
i++
}
@@ -773,27 +892,30 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
// PgSQLDeleteQuery implements DeleteQuery for PostgreSQL
type PgSQLDeleteQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
whereClauses []string
args []interface{}
paramCounter int
db *sql.DB
tx *sql.Tx
schema string
tableName string
entity string
driverName string
whereClauses []string
args []interface{}
paramCounter int
metricsEnabled bool
}
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
return p
}
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
return p
}
@@ -815,10 +937,12 @@ func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) strin
}
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
}
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
}()
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
@@ -866,66 +990,80 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
// PgSQLTxAdapter wraps a PostgreSQL transaction
type PgSQLTxAdapter struct {
tx *sql.Tx
driverName string
tx *sql.Tx
driverName string
metricsEnabled bool
}
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
tx: p.tx,
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
tx: p.tx,
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
tx: p.tx,
driverName: p.driverName,
values: make(map[string]interface{}),
tx: p.tx,
driverName: p.driverName,
values: make(map[string]interface{}),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
tx: p.tx,
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
tx: p.tx,
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
tx: p.tx,
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
tx: p.tx,
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
result, err := p.tx.ExecContext(ctx, query, args...)
if err != nil {
logger.Error("PgSQL Tx Exec failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
rows, err := p.tx.QueryContext(ctx, query, args...)
if err != nil {
logger.Error("PgSQL Tx Query failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
defer rows.Close()
return scanRows(rows, dest)
err = scanRows(rows, dest)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {

View File

@@ -0,0 +1,335 @@
package database
import (
"reflect"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
const maxMetricFallbackEntityLength = 120
func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) {
if !enabled {
return
}
metrics.GetProvider().RecordDBQuery(
normalizeMetricOperation(operation),
normalizeMetricSchema(schema),
normalizeMetricEntity(entity, table),
normalizeMetricTable(table),
time.Since(startedAt),
err,
)
}
func normalizeMetricOperation(operation string) string {
operation = strings.ToUpper(strings.TrimSpace(operation))
if operation == "" {
return "UNKNOWN"
}
return operation
}
func normalizeMetricSchema(schema string) string {
schema = cleanMetricIdentifier(schema)
if schema == "" {
return "default"
}
return schema
}
func normalizeMetricEntity(entity, table string) string {
entity = cleanMetricIdentifier(entity)
if entity != "" {
return entity
}
table = cleanMetricIdentifier(table)
if table != "" {
return table
}
return "unknown"
}
func normalizeMetricTable(table string) string {
table = cleanMetricIdentifier(table)
if table == "" {
return "unknown"
}
return table
}
func entityNameFromModel(model interface{}, table string) string {
if model == nil {
return cleanMetricIdentifier(table)
}
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil {
return cleanMetricIdentifier(table)
}
if modelType.Kind() == reflect.Struct && modelType.Name() != "" {
return reflection.ToSnakeCase(modelType.Name())
}
return cleanMetricIdentifier(table)
}
func schemaAndTableFromModel(model interface{}, driverName string) (schema, table string) {
provider, ok := tableNameProviderFromModel(model)
if !ok {
return "", ""
}
return parseTableName(provider.TableName(), driverName)
}
// tableNameProviderType is cached to avoid repeated reflection on every call.
var tableNameProviderType = reflect.TypeOf((*common.TableNameProvider)(nil)).Elem()
func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) {
if model == nil {
return nil, false
}
if provider, ok := model.(common.TableNameProvider); ok {
return provider, true
}
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil, false
}
// Check whether *T implements TableNameProvider before allocating.
ptrType := reflect.PointerTo(modelType)
if !ptrType.Implements(tableNameProviderType) && !modelType.Implements(tableNameProviderType) {
return nil, false
}
modelValue := reflect.New(modelType)
if provider, ok := modelValue.Interface().(common.TableNameProvider); ok {
return provider, true
}
if provider, ok := modelValue.Elem().Interface().(common.TableNameProvider); ok {
return provider, true
}
return nil, false
}
func metricTargetFromRawQuery(query, driverName string) (operation, schema, entity, table string) {
operation = normalizeMetricOperation(firstQueryKeyword(query))
tableRef := tableFromRawQuery(query, operation)
if tableRef == "" {
return operation, "", fallbackMetricEntityFromQuery(query), "unknown"
}
schema, table = parseTableName(tableRef, driverName)
entity = cleanMetricIdentifier(table)
return operation, schema, entity, table
}
func fallbackMetricEntityFromQuery(query string) string {
query = sanitizeMetricQueryShape(query)
if query == "" {
return "unknown"
}
if len(query) > maxMetricFallbackEntityLength {
return query[:maxMetricFallbackEntityLength-3] + "..."
}
return query
}
func sanitizeMetricQueryShape(query string) string {
query = strings.TrimSpace(query)
if query == "" {
return ""
}
var out strings.Builder
for i := 0; i < len(query); {
if query[i] == '\'' {
out.WriteByte('?')
i++
for i < len(query) {
if query[i] == '\'' {
if i+1 < len(query) && query[i+1] == '\'' {
i += 2
continue
}
i++
break
}
i++
}
continue
}
if query[i] == '?' {
out.WriteByte('?')
i++
continue
}
if query[i] == '$' && i+1 < len(query) && isASCIIDigit(query[i+1]) {
out.WriteByte('?')
i++
for i < len(query) && isASCIIDigit(query[i]) {
i++
}
continue
}
if query[i] == ':' && (i == 0 || query[i-1] != ':') && i+1 < len(query) && isIdentifierStart(query[i+1]) {
out.WriteByte('?')
i++
for i < len(query) && isIdentifierPart(query[i]) {
i++
}
continue
}
if query[i] == '@' && (i == 0 || query[i-1] != '@') && i+1 < len(query) && isIdentifierStart(query[i+1]) {
out.WriteByte('?')
i++
for i < len(query) && isIdentifierPart(query[i]) {
i++
}
continue
}
if startsNumericLiteral(query, i) {
out.WriteByte('?')
i++
for i < len(query) && (isASCIIDigit(query[i]) || query[i] == '.') {
i++
}
continue
}
out.WriteByte(query[i])
i++
}
return strings.Join(strings.Fields(out.String()), " ")
}
func startsNumericLiteral(query string, idx int) bool {
if idx >= len(query) {
return false
}
start := idx
if query[idx] == '-' {
if idx+1 >= len(query) || !isASCIIDigit(query[idx+1]) {
return false
}
start++
}
if !isASCIIDigit(query[start]) {
return false
}
if idx > 0 && isIdentifierPart(query[idx-1]) {
return false
}
if start+1 < len(query) && query[start] == '0' && (query[start+1] == 'x' || query[start+1] == 'X') {
return false
}
return true
}
func isASCIIDigit(ch byte) bool {
return ch >= '0' && ch <= '9'
}
func isIdentifierStart(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
}
func isIdentifierPart(ch byte) bool {
return isIdentifierStart(ch) || isASCIIDigit(ch)
}
func firstQueryKeyword(query string) string {
query = strings.TrimSpace(query)
if query == "" {
return ""
}
fields := strings.Fields(query)
if len(fields) == 0 {
return ""
}
return fields[0]
}
func tableFromRawQuery(query, operation string) string {
tokens := tokenizeQuery(query)
if len(tokens) == 0 {
return ""
}
switch operation {
case "SELECT":
return tokenAfter(tokens, "FROM")
case "INSERT":
return tokenAfter(tokens, "INTO")
case "UPDATE":
return tokenAfter(tokens, "UPDATE")
case "DELETE":
return tokenAfter(tokens, "FROM")
default:
return ""
}
}
func tokenAfter(tokens []string, keyword string) string {
for idx, token := range tokens {
if strings.EqualFold(token, keyword) && idx+1 < len(tokens) {
return cleanMetricIdentifier(tokens[idx+1])
}
}
return ""
}
func tokenizeQuery(query string) []string {
replacer := strings.NewReplacer(
"\n", " ",
"\t", " ",
"(", " ",
")", " ",
",", " ",
)
return strings.Fields(replacer.Replace(query))
}
func cleanMetricIdentifier(value string) string {
value = strings.TrimSpace(value)
value = strings.Trim(value, "\"'`[]")
value = strings.TrimRight(value, ";")
return value
}

View File

@@ -0,0 +1,348 @@
package database
import (
"context"
"database/sql"
"fmt"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/driver/sqliteshim"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
type queryMetricCall struct {
operation string
schema string
entity string
table string
}
type capturingMetricsProvider struct {
mu sync.Mutex
calls []queryMetricCall
}
func (c *capturingMetricsProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
}
func (c *capturingMetricsProvider) IncRequestsInFlight() {}
func (c *capturingMetricsProvider) DecRequestsInFlight() {}
func (c *capturingMetricsProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.calls = append(c.calls, queryMetricCall{
operation: operation,
schema: schema,
entity: entity,
table: table,
})
}
func (c *capturingMetricsProvider) RecordCacheHit(provider string) {}
func (c *capturingMetricsProvider) RecordCacheMiss(provider string) {}
func (c *capturingMetricsProvider) UpdateCacheSize(provider string, size int64) {
}
func (c *capturingMetricsProvider) RecordEventPublished(source, eventType string) {}
func (c *capturingMetricsProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
}
func (c *capturingMetricsProvider) UpdateEventQueueSize(size int64) {}
func (c *capturingMetricsProvider) RecordPanic(methodName string) {}
func (c *capturingMetricsProvider) Handler() http.Handler { return http.NewServeMux() }
func (c *capturingMetricsProvider) snapshot() []queryMetricCall {
c.mu.Lock()
defer c.mu.Unlock()
out := make([]queryMetricCall, len(c.calls))
copy(out, c.calls)
return out
}
type queryMetricsGormUser struct {
ID int `gorm:"primaryKey"`
Name string
}
func (queryMetricsGormUser) TableName() string {
return "metrics_gorm_users"
}
type queryMetricsBunUser struct {
bun.BaseModel `bun:"table:metrics_bun_users"`
ID int64 `bun:"id,pk,autoincrement"`
Name string `bun:"name"`
}
func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`).
WithArgs("Alice", 1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
_, err = adapter.NewUpdate().
Table("public.users").
Set("name", "Alice").
Where("id = ?", 1).
Exec(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "UPDATE", calls[0].operation)
assert.Equal(t, "public", calls[0].schema)
assert.Equal(t, "users", calls[0].entity)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectExec(`DELETE FROM users WHERE id = \$1`).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db).SetMetricsEnabled(false)
_, err = adapter.NewDelete().
Table("users").
Where("id = ?", 1).
Exec(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
assert.Empty(t, provider.snapshot())
}
func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&queryMetricsGormUser{}))
require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error)
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
adapter := NewGormAdapter(db)
var users []queryMetricsGormUser
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
require.NoError(t, err)
require.NotEmpty(t, users)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "SELECT", calls[0].operation)
assert.Equal(t, "default", calls[0].schema)
assert.Equal(t, "query_metrics_gorm_user", calls[0].entity)
assert.Equal(t, "metrics_gorm_users", calls[0].table)
}
func TestPgSQLAdapterRecordsErrorMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectExec(`INSERT INTO users`).
WillReturnError(fmt.Errorf("unique constraint violation"))
adapter := NewPgSQLAdapter(db)
_, err = adapter.NewInsert().
Table("users").
Value("name", "Alice").
Exec(context.Background())
require.Error(t, err)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "INSERT", calls[0].operation)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterRecordsExistsMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3))
adapter := NewPgSQLAdapter(db)
exists, err := adapter.NewSelect().Table("users").Exists(context.Background())
require.NoError(t, err)
assert.True(t, exists)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "EXISTS", calls[0].operation)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterRecordsCountMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
adapter := NewPgSQLAdapter(db)
count, err := adapter.NewSelect().Table("users").Count(context.Background())
require.NoError(t, err)
assert.Equal(t, 5, count)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "COUNT", calls[0].operation)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectExec(`UPDATE public\.orders SET status = \$1`).
WithArgs("shipped").
WillReturnResult(sqlmock.NewResult(0, 2))
adapter := NewPgSQLAdapter(db)
_, err = adapter.Exec(context.Background(), `UPDATE public.orders SET status = $1`, "shipped")
require.NoError(t, err)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "UPDATE", calls[0].operation)
assert.Equal(t, "public", calls[0].schema)
assert.Equal(t, "orders", calls[0].table)
}
func TestPgSQLAdapterRawExecUsesSQLAsEntityWhenTargetUnknown(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
query := `select core.c_setuserid($1)`
mock.ExpectExec(`select core\.c_setuserid\(\$1\)`).
WithArgs(42).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
_, err = adapter.Exec(context.Background(), query, 42)
require.NoError(t, err)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "SELECT", calls[0].operation)
assert.Equal(t, "default", calls[0].schema)
assert.Equal(t, "select core.c_setuserid(?)", calls[0].entity)
assert.Equal(t, "unknown", calls[0].table)
}
func TestFallbackMetricEntityFromQuerySanitizesAndTruncates(t *testing.T) {
entity := fallbackMetricEntityFromQuery(" \n SELECT some_function(1, 'abc', $2, ?, :name, @p1, true, null) \t ")
assert.Equal(t, "SELECT some_function(?, ?, ?, ?, ?, ?, true, null)", entity)
entity = fallbackMetricEntityFromQuery("SELECT price::numeric, id FROM logs WHERE code = -42")
assert.Equal(t, "SELECT price::numeric, id FROM logs WHERE code = ?", entity)
longQuery := "SELECT " + strings.Repeat("x", maxMetricFallbackEntityLength)
entity = fallbackMetricEntityFromQuery(longQuery)
assert.Len(t, entity, maxMetricFallbackEntityLength)
assert.True(t, strings.HasSuffix(entity, "..."))
}
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
require.NoError(t, err)
defer sqldb.Close()
db := bun.NewDB(sqldb, sqlitedialect.New())
defer db.Close()
_, err = db.NewCreateTable().
Model((*queryMetricsBunUser)(nil)).
IfNotExists().
Exec(context.Background())
require.NoError(t, err)
_, err = db.NewInsert().Model(&queryMetricsBunUser{Name: "Alice"}).Exec(context.Background())
require.NoError(t, err)
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
adapter := NewBunAdapter(db)
var users []queryMetricsBunUser
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
require.NoError(t, err)
require.NotEmpty(t, users)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "SELECT", calls[0].operation)
assert.Equal(t, "default", calls[0].schema)
assert.Equal(t, "query_metrics_bun_user", calls[0].entity)
assert.Equal(t, "metrics_bun_users", calls[0].table)
}

View File

@@ -98,8 +98,8 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
}
}
// Filter regularData to only include fields that exist in the model
// Use MapToStruct to validate and filter fields
// Filter regularData to only include fields that exist in the model,
// and translate JSON keys to their actual database column names.
regularData = p.filterValidFields(regularData, model)
// Inject parent IDs for foreign key resolution
@@ -191,14 +191,15 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
return ""
}
// filterValidFields filters input data to only include fields that exist in the model
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
// filterValidFields filters input data to only include fields that exist in the model,
// and translates JSON key names to their actual database column names.
// For example, a field tagged `json:"_changed_date" bun:"changed_date"` will be
// included in the result as "changed_date", not "_changed_date".
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
if len(data) == 0 {
return data
}
// Create a new instance of the model to use with MapToStruct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
@@ -208,25 +209,16 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
return data
}
// Create a new instance of the model
tempModel := reflect.New(modelType).Interface()
// Build a mapping from JSON key -> DB column name for all writable fields.
// This both validates which fields belong to the model and translates their names
// to the correct column names for use in SQL insert/update queries.
jsonToDBCol := reflection.BuildJSONToDBColumnMap(modelType)
// Use MapToStruct to map the data - this will only map valid fields
err := reflection.MapToStruct(data, tempModel)
if err != nil {
logger.Debug("Error mapping data to model: %v", err)
return data
}
// Extract the mapped fields back into a map
// This effectively filters out any fields that don't exist in the model
filteredData := make(map[string]interface{})
tempModelValue := reflect.ValueOf(tempModel).Elem()
for key, value := range data {
// Check if the field was successfully mapped
if fieldWasMapped(tempModelValue, modelType, key) {
filteredData[key] = value
dbColName, exists := jsonToDBCol[key]
if exists {
filteredData[dbColName] = value
} else {
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
}
@@ -235,72 +227,8 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
return filteredData
}
// fieldWasMapped checks if a field with the given key was mapped to the model
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
// Look for the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] == key {
return true
}
}
// Check bun tag
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
return true
}
}
// Check gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
return true
}
}
// Check lowercase field name
if strings.EqualFold(field.Name, key) {
return true
}
// Handle embedded structs recursively
if field.Anonymous {
fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
if fieldType.Kind() == reflect.Struct {
embeddedValue := modelValue.Field(i)
if embeddedValue.Kind() == reflect.Ptr {
if embeddedValue.IsNil() {
continue
}
embeddedValue = embeddedValue.Elem()
}
if fieldWasMapped(embeddedValue, fieldType, key) {
return true
}
}
}
}
return false
}
// injectForeignKeys injects parent IDs into data for foreign key fields
// injectForeignKeys injects parent IDs into data for foreign key fields.
// data is expected to be keyed by DB column names (as returned by filterValidFields).
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
if len(parentIDs) == 0 {
return
@@ -319,10 +247,11 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
if strings.EqualFold(jsonName, parentKey+"_id") ||
strings.EqualFold(jsonName, parentKey+"id") ||
strings.EqualFold(field.Name, parentKey+"ID") {
// Only inject if not already present
if _, exists := data[jsonName]; !exists {
logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID)
data[jsonName] = parentID
// Use the DB column name as the key, since data is keyed by DB column names
dbColName := reflection.GetColumnName(field)
if _, exists := data[dbColName]; !exists {
logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID)
data[dbColName] = parentID
}
}
}

View File

@@ -168,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
}
// Build a set of allowed table prefixes (main table + preloaded relations)
// Keys are stored lowercase for case-insensitive matching
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
allowedPrefixes[strings.ToLower(tableName)] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
@@ -185,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
// Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases {
if alias != "" {
allowedPrefixes[alias] = true
allowedPrefixes[strings.ToLower(alias)] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
}
}
@@ -217,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation)
if !allowedPrefixes[currentPrefix] {
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name

View File

@@ -26,6 +26,7 @@ type Connection interface {
Bun() (*bun.DB, error)
GORM() (*gorm.DB, error)
Native() (*sql.DB, error)
DB() (*sql.DB, error)
// Common Database interface (for SQL databases)
Database() (common.Database, error)
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
return c.nativeDB, nil
}
// DB returns the underlying *sql.DB connection
func (c *sqlConnection) DB() (*sql.DB, error) {
return c.Native()
}
// Bun returns a Bun ORM instance wrapping the native connection
func (c *sqlConnection) Bun() (*bun.DB, error) {
if c == nil {
@@ -353,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
return stats
}
func (c *sqlConnection) reconnectForAdapter() error {
timeout := c.config.ConnectTimeout
if timeout <= 0 {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.Reconnect(ctx)
}
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
if err := c.reconnectForAdapter(); err != nil {
return nil, err
}
return c.Native()
}
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
if err := c.reconnectForAdapter(); err != nil {
return nil, err
}
return c.Bun()
}
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
if err := c.reconnectForAdapter(); err != nil {
return nil, err
}
return c.GORM()
}
// getBunAdapter returns or creates the Bun adapter
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
if c == nil {
@@ -385,7 +427,9 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
c.bunDB = bun.NewDB(native, dialect)
}
c.bunAdapter = database.NewBunAdapter(c.bunDB)
c.bunAdapter = database.NewBunAdapter(c.bunDB).
WithDBFactory(c.reopenBunForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
return c.bunAdapter, nil
}
@@ -426,7 +470,9 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
c.gormDB = db
}
c.gormAdapter = database.NewGormAdapter(c.gormDB)
c.gormAdapter = database.NewGormAdapter(c.gormDB).
WithDBFactory(c.reopenGORMForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
return c.gormAdapter, nil
}
@@ -467,11 +513,17 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
// Create a native adapter based on database type
switch c.dbType {
case DatabaseTypePostgreSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
WithDBFactory(c.reopenNativeForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
case DatabaseTypeSQLite:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
WithDBFactory(c.reopenNativeForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
case DatabaseTypeMSSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
WithDBFactory(c.reopenNativeForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
default:
return nil, ErrUnsupportedDatabase
}
@@ -645,6 +697,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// DB returns an error for MongoDB connections
func (c *mongoConnection) DB() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// Database returns an error for MongoDB connections
func (c *mongoConnection) Database() (common.Database, error) {
return nil, ErrNotSQLDatabase

View File

@@ -4,8 +4,13 @@ import (
"context"
"database/sql"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
)
func TestNewConnectionFromDB(t *testing.T) {
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
}
}
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
Name: "test-native",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
DefaultORM: string(ORMTypeNative),
ConnectTimeout: 2 * time.Second,
}, providers.NewSQLiteProvider())
ctx := context.Background()
if err := conn.Connect(ctx); err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
db, err := conn.Database()
if err != nil {
t.Fatalf("Failed to get database adapter: %v", err)
}
adapter, ok := db.(*database.PgSQLAdapter)
if !ok {
t.Fatalf("Expected PgSQLAdapter, got %T", db)
}
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
if !ok {
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
}
if err := underlyingBefore.Close(); err != nil {
t.Fatalf("Failed to close underlying database: %v", err)
}
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
}
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
if !ok {
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
}
if underlyingAfter == underlyingBefore {
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
}
}
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
Name: "test-bun",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
DefaultORM: string(ORMTypeBun),
ConnectTimeout: 2 * time.Second,
}, providers.NewSQLiteProvider())
ctx := context.Background()
if err := conn.Connect(ctx); err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
db, err := conn.Database()
if err != nil {
t.Fatalf("Failed to get database adapter: %v", err)
}
adapter, ok := db.(*database.BunAdapter)
if !ok {
t.Fatalf("Expected BunAdapter, got %T", db)
}
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
if !ok {
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
}
if err := underlyingBefore.Close(); err != nil {
t.Fatalf("Failed to close underlying Bun database: %v", err)
}
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
}
underlyingAfter := adapter.GetUnderlyingDB()
if underlyingAfter == underlyingBefore {
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
}
}
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
Name: "test-gorm",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
DefaultORM: string(ORMTypeGORM),
ConnectTimeout: 2 * time.Second,
}, providers.NewSQLiteProvider())
ctx := context.Background()
if err := conn.Connect(ctx); err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
db, err := conn.Database()
if err != nil {
t.Fatalf("Failed to get database adapter: %v", err)
}
adapter, ok := db.(*database.GormAdapter)
if !ok {
t.Fatalf("Expected GormAdapter, got %T", db)
}
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
if !ok {
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
}
sqlBefore, err := gormBefore.DB()
if err != nil {
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
}
if err := sqlBefore.Close(); err != nil {
t.Fatalf("Failed to close underlying database: %v", err)
}
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
if err != nil {
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
}
if count < 0 {
t.Fatalf("Expected non-negative count, got %d", count)
}
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
if !ok {
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
}
sqlAfter, err := gormAfter.DB()
if err != nil {
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
}
if sqlAfter == sqlBefore {
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
}
}

View File

@@ -2,7 +2,9 @@ package dbmanager
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
"connection", item.name,
"error", err)
// Attempt reconnection if enabled
if m.config.EnableAutoReconnect {
// Only reconnect when the client handle itself is closed/disconnected.
// For transient database restarts or network blips, *sql.DB can recover
// on its own; forcing Close()+Connect() here invalidates any cached ORM
// wrappers and callers that still hold the old handle.
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
logger.Info("Attempting reconnection: connection=%s", item.name)
if err := item.conn.Reconnect(ctx); err != nil {
logger.Error("Reconnection failed",
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
} else {
logger.Info("Reconnection successful: connection=%s", item.name)
}
} else if m.config.EnableAutoReconnect {
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
}
}
}
}
func shouldReconnectAfterHealthCheck(err error) bool {
if err == nil {
return false
}
if errors.Is(err, ErrConnectionClosed) {
return true
}
return strings.Contains(err.Error(), "sql: database is closed")
}

View File

@@ -3,12 +3,38 @@ package dbmanager
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/uptrace/bun"
"go.mongodb.org/mongo-driver/mongo"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
_ "github.com/mattn/go-sqlite3"
)
type healthCheckStubConnection struct {
healthErr error
reconnectCalls int
}
func (c *healthCheckStubConnection) Name() string { return "stub" }
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
func (c *healthCheckStubConnection) Close() error { return nil }
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
func TestBackgroundHealthChecker(t *testing.T) {
// Create a SQLite in-memory database
db, err := sql.Open("sqlite3", ":memory:")
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
}
}
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
conn := &healthCheckStubConnection{
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
}
mgr := &connectionManager{
connections: map[string]Connection{"primary": conn},
config: ManagerConfig{
EnableAutoReconnect: true,
},
}
mgr.performHealthCheck()
if conn.reconnectCalls != 0 {
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
}
}
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
conn := &healthCheckStubConnection{
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
}
mgr := &connectionManager{
connections: map[string]Connection{"primary": conn},
config: ManagerConfig{
EnableAutoReconnect: true,
},
}
mgr.performHealthCheck()
if conn.reconnectCalls != 1 {
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
}
}

View File

@@ -3,6 +3,7 @@ package providers_test
import (
"context"
"fmt"
"log"
"time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err))
log.Fatalf("Failed to connect: %v", err)
}
defer provider.Close()
// Get listener
listener, err := provider.GetListener(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err))
log.Fatalf("Failed to get listener: %v", err)
}
// Subscribe to a channel with a handler
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
fmt.Printf("Received notification on %s: %s\n", channel, payload)
})
if err != nil {
panic(fmt.Sprintf("Failed to listen: %v", err))
log.Fatalf("Failed to listen: %v", err)
}
// Send a notification
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
if err != nil {
panic(fmt.Sprintf("Failed to notify: %v", err))
log.Fatalf("Failed to notify: %v", err)
}
// Wait for notification to be processed
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
// Unsubscribe from the channel
if err := listener.Unlisten("user_events"); err != nil {
panic(fmt.Sprintf("Failed to unlisten: %v", err))
log.Fatalf("Failed to unlisten: %v", err)
}
}
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err))
log.Fatalf("Failed to connect: %v", err)
}
defer provider.Close()
listener, err := provider.GetListener(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err))
log.Fatalf("Failed to get listener: %v", err)
}
// Listen to multiple channels
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
fmt.Printf("[%s] %s\n", ch, payload)
})
if err != nil {
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
log.Fatalf("Failed to listen on %s: %v", channel, err)
}
}
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
provider := providers.NewPostgresProvider()
if err := provider.Connect(ctx, cfg); err != nil {
panic(err)
log.Fatal(err)
}
defer provider.Close()
// Get listener
listener, err := provider.GetListener(ctx)
if err != nil {
panic(err)
log.Fatal(err)
}
// Subscribe to application events
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err))
log.Fatalf("Failed to connect: %v", err)
}
defer provider.Close()
listener, err := provider.GetListener(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err))
log.Fatalf("Failed to get listener: %v", err)
}
// The listener automatically reconnects if the connection is lost

View File

@@ -4,11 +4,17 @@ import (
"context"
"database/sql"
"errors"
"strings"
"time"
"go.mongodb.org/mongo-driver/mongo"
)
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// Common errors
var (
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"sync"
"time"
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
@@ -14,8 +15,10 @@ import (
// SQLiteProvider implements Provider for SQLite databases
type SQLiteProvider struct {
db *sql.DB
config ConnectionConfig
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
config ConnectionConfig
}
// NewSQLiteProvider creates a new SQLite provider
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
// Execute a simple query to verify the database is accessible
var result int
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
err := run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
return fmt.Errorf("health check failed: %w", err)
}
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
return nil
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
p.dbFactory = factory
return p
}
func (p *SQLiteProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *SQLiteProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// GetNative returns the native *sql.DB connection
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
if p.db == nil {

View File

@@ -2,6 +2,7 @@ package metrics
import (
"net/http"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -19,7 +20,7 @@ type Provider interface {
DecRequestsInFlight()
// RecordDBQuery records metrics for a database query
RecordDBQuery(operation, table string, duration time.Duration, err error)
RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error)
// RecordCacheHit records a cache hit
RecordCacheHit(provider string)
@@ -46,21 +47,28 @@ type Provider interface {
Handler() http.Handler
}
// globalProvider is the global metrics provider
var globalProvider Provider
// globalProvider is the global metrics provider, protected by globalProviderMu.
var (
globalProviderMu sync.RWMutex
globalProvider Provider
)
// SetProvider sets the global metrics provider
// SetProvider sets the global metrics provider.
func SetProvider(p Provider) {
globalProviderMu.Lock()
globalProvider = p
globalProviderMu.Unlock()
}
// GetProvider returns the current metrics provider
// GetProvider returns the current metrics provider.
func GetProvider() Provider {
if globalProvider == nil {
// Return no-op provider if none is set
globalProviderMu.RLock()
p := globalProvider
globalProviderMu.RUnlock()
if p == nil {
return &NoOpProvider{}
}
return globalProvider
return p
}
// NoOpProvider is a no-op implementation of Provider
@@ -69,7 +77,7 @@ type NoOpProvider struct{}
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
func (n *NoOpProvider) IncRequestsInFlight() {}
func (n *NoOpProvider) DecRequestsInFlight() {}
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
func (n *NoOpProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
}
func (n *NoOpProvider) RecordCacheHit(provider string) {}
func (n *NoOpProvider) RecordCacheMiss(provider string) {}

View File

@@ -83,14 +83,14 @@ func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
Help: "Database query duration in seconds",
Buckets: cfg.DBQueryBuckets,
},
[]string{"operation", "table"},
[]string{"operation", "schema", "entity", "table"},
),
dbQueryTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: metricName("db_queries_total"),
Help: "Total number of database queries",
},
[]string{"operation", "table", "status"},
[]string{"operation", "schema", "entity", "table", "status"},
),
cacheHits: promauto.NewCounterVec(
prometheus.CounterOpts{
@@ -204,13 +204,13 @@ func (p *PrometheusProvider) DecRequestsInFlight() {
}
// RecordDBQuery implements Provider interface
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
func (p *PrometheusProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
status := "success"
if err != nil {
status = "error"
}
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
p.dbQueryDuration.WithLabelValues(operation, schema, entity, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, schema, entity, table, status).Inc()
}
// RecordCacheHit implements Provider interface

View File

@@ -196,6 +196,92 @@ func collectColumnsFromType(typ reflect.Type, columns *[]string) {
}
}
// GetColumnName extracts the database column name from a struct field.
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name.
// This is the exported version for use by other packages.
func GetColumnName(field reflect.StructField) string {
return getColumnNameFromField(field)
}
// BuildJSONToDBColumnMap returns a map from JSON key names to database column names
// for the given model type. Only writable, non-relation fields are included.
// This is used to translate incoming request data (keyed by JSON names) into
// properly named database columns before insert/update operations.
func BuildJSONToDBColumnMap(modelType reflect.Type) map[string]string {
result := make(map[string]string)
buildJSONToDBMap(modelType, result, false)
return result
}
func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly bool) {
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if !field.IsExported() {
continue
}
bunTag := field.Tag.Get("bun")
gormTag := field.Tag.Get("gorm")
// Handle embedded structs
if field.Anonymous {
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isScanOnly := scanOnly
if bunTag != "" && isBunFieldScanOnly(bunTag) {
isScanOnly = true
}
if ft.Kind() == reflect.Struct {
buildJSONToDBMap(ft, result, isScanOnly)
continue
}
}
if scanOnly {
continue
}
// Skip explicitly excluded fields
if bunTag == "-" || gormTag == "-" {
continue
}
// Skip scan-only fields
if bunTag != "" && isBunFieldScanOnly(bunTag) {
continue
}
// Skip bun relation fields
if bunTag != "" && (strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") || strings.Contains(bunTag, "m2m:")) {
continue
}
// Skip gorm relation fields
if gormTag != "" && (strings.Contains(gormTag, "foreignKey:") || strings.Contains(gormTag, "references:") || strings.Contains(gormTag, "many2many:")) {
continue
}
// Get JSON key (how the field appears in incoming request data)
jsonKey := ""
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
jsonKey = parts[0]
}
}
if jsonKey == "" {
jsonKey = strings.ToLower(field.Name)
}
// Get the actual DB column name (bun > gorm > json > field name)
dbColName := getColumnNameFromField(field)
result[jsonKey] = dbColName
}
}
// getColumnNameFromField extracts the column name from a struct field
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
func getColumnNameFromField(field reflect.StructField) string {

View File

@@ -823,12 +823,12 @@ func TestToSnakeCase(t *testing.T) {
{
name: "UserID",
input: "UserID",
expected: "user_i_d",
expected: "user_id",
},
{
name: "HTTPServer",
input: "HTTPServer",
expected: "h_t_t_p_server",
expected: "http_server",
},
{
name: "lowercase",
@@ -838,7 +838,7 @@ func TestToSnakeCase(t *testing.T) {
{
name: "UPPERCASE",
input: "UPPERCASE",
expected: "u_p_p_e_r_c_a_s_e",
expected: "uppercase",
},
{
name: "Single",

671
pkg/resolvemcp/README.md Normal file
View File

@@ -0,0 +1,671 @@
# resolvemcp
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
## Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
"github.com/gorilla/mux"
)
// 1. Create a handler
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
BaseURL: "http://localhost:8080",
})
// 2. Register models
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "orders", &Order{})
// 3. Mount routes
r := mux.NewRouter()
resolvemcp.SetupMuxRoutes(r, handler)
```
---
## Config
```go
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
// If empty, it is detected from each incoming request using the Host header and
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
BaseURL string
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
// Required.
BasePath string
}
```
## Handler Creation
| Function | Description |
|---|---|
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
---
## Registering Models
```go
handler.RegisterModel(schema, entity string, model interface{}) error
```
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
- `entity` — table/entity name (e.g. `"users"`).
- `model` — a pointer to a struct (e.g. `&User{}`).
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
---
## HTTP Transports
`Config.BasePath` is required and used for all route registration.
`Config.BaseURL` is optional — when empty it is detected from each request.
Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
---
### SSE Transport
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
#### Gorilla Mux
```go
resolvemcp.SetupMuxRoutes(r, handler)
```
| Route | Method | Description |
|---|---|---|
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
#### bunrouter
```go
resolvemcp.SetupBunRouterRoutes(router, handler)
```
#### Gin / net/http / Echo
```go
sse := handler.SSEServer()
engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
http.Handle("/mcp/", sse) // net/http
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
```
---
### Streamable HTTP Transport
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
#### Gorilla Mux
```go
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
```
Mounts the handler at `{BasePath}` (all methods).
#### bunrouter
```go
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
```
Registers GET, POST, DELETE on `{BasePath}`.
#### Gin / net/http / Echo
```go
h := handler.StreamableHTTPServer()
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
engine.Any("/mcp", gin.WrapH(h)) // Gin
http.Handle("/mcp", h) // net/http
e.Any("/mcp", echo.WrapHandler(h)) // Echo
```
---
## OAuth2 Authentication
`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
It can operate as:
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
- **Both simultaneously**
### Standard endpoints served
| Path | Spec | Purpose |
|---|---|---|
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
| `POST /oauth/authorize` | — | Login form submission |
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
---
### Mode 1 — Direct login (server as identity provider)
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
db, _ := sql.Open("postgres", dsn)
auth := security.NewDatabaseAuthenticator(db)
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
BaseURL: "https://api.example.com",
BasePath: "/mcp",
})
// Enable the OAuth2 server — auth enables the login form
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
security.RegisterSecurityHooks(handler, securityList)
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
```
MCP client flow:
1. Discovers server at `/.well-known/oauth-authorization-server`
2. Registers itself at `/oauth/register`
3. Redirects user to `/oauth/authorize` → login form appears
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
5. Uses token on all MCP tool calls
---
### Mode 2 — External provider (Google, GitHub, etc.)
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
```go
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
RedirectURL: "https://api.example.com/oauth/provider/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
})
// Pass `auth` so the OAuth server supports persistence, introspection, and revocation.
// Google handles the end-user authentication flow via redirect.
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
handler.RegisterOAuth2Provider(auth, "google")
```
---
### Mode 3 — Both (login form + external providers)
```go
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
LoginTitle: "My App Login",
}, auth) // auth enables the username/password form
handler.RegisterOAuth2Provider(googleAuth, "google")
handler.RegisterOAuth2Provider(githubAuth, "github")
```
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
---
### Using `security.OAuthServer` standalone
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
```go
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
oauthSrv.RegisterExternalProvider(googleAuth, "google")
mux := http.NewServeMux()
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
mux.Handle("/mcp/", myMCPHandler)
http.ListenAndServe(":8080", mux)
```
---
### Cookie-based flow (legacy)
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
```go
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
ProviderName: "google",
LoginPath: "/auth/google/login",
CallbackPath: "/auth/google/callback",
AfterLoginRedirect: "/",
})
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
```
---
## Security
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
### Wiring security hooks
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
securityList := security.NewSecurityList(mySecurityProvider)
resolvemcp.RegisterSecurityHooks(handler, securityList)
```
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
| Hook | Effect |
|---|---|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
### Per-entity operation rules
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
```go
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
// Read-only entity
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
CanRead: true,
CanCreate: false,
CanUpdate: false,
CanDelete: false,
})
// Public read, authenticated write
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
CanPublicRead: true,
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
To update rules for an already-registered model:
```go
handler.SetModelRules("public", "users", modelregistry.ModelRules{
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
### ModelRules fields
| Field | Default | Description |
|---|---|---|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
| `CanRead` | `true` | Allow authenticated reads |
| `CanCreate` | `true` | Allow authenticated creates |
| `CanUpdate` | `true` | Allow authenticated updates |
| `CanDelete` | `true` | Allow authenticated deletes |
| `SecurityDisabled` | `false` | Skip all security checks for this model |
---
## MCP Tools
### Tool Naming
```
{operation}_{schema}_{entity} // e.g. read_public_users
{operation}_{entity} // e.g. read_users (when schema is empty)
```
Operations: `read`, `create`, `update`, `delete`.
### Read Tool — `read_{schema}_{entity}`
Fetch one or many records.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key value. Omit to return multiple records. |
| `limit` | number | Max records per page (recommended: 10100). |
| `offset` | number | Records to skip (offset-based pagination). |
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
| `columns` | array | Column names to include. Omit for all columns. |
| `omit_columns` | array | Column names to exclude. |
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
**Response:**
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 100,
"count": 10,
"limit": 10,
"offset": 0
}
}
```
### Create Tool — `create_{schema}_{entity}`
Insert one or more records.
| Argument | Type | Description |
|---|---|---|
| `data` | object \| array | Single object or array of objects to insert. |
Array input runs inside a single transaction — all succeed or all fail.
**Response:**
```json
{ "success": true, "data": { ... } }
```
### Update Tool — `update_{schema}_{entity}`
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key of the record. Can also be included inside `data`. |
| `data` | object (required) | Fields to update. |
**Response:**
```json
{ "success": true, "data": { ...merged record... } }
```
### Delete Tool — `delete_{schema}_{entity}`
Delete a record by primary key. **Irreversible.**
| Argument | Type | Description |
|---|---|---|
| `id` | string (required) | Primary key of the record to delete. |
**Response:**
```json
{ "success": true, "data": { ...deleted record... } }
```
### Annotation Tool — `resolvespec_annotate`
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
| Argument | Type | Description |
|---|---|---|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
```json
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "set" }
```
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
```json
{ "tool_name": "read_public_users" }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
```
---
### Resource — `{schema}.{entity}`
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
---
## Filtering
Pass an array of filter objects to the `filters` argument:
```json
[
{ "column": "status", "operator": "=", "value": "active" },
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
]
```
### Supported Operators
| Operator | Aliases | Description |
|---|---|---|
| `=` | `eq` | Equal |
| `!=` | `neq`, `<>` | Not equal |
| `>` | `gt` | Greater than |
| `>=` | `gte` | Greater than or equal |
| `<` | `lt` | Less than |
| `<=` | `lte` | Less than or equal |
| `like` | | SQL LIKE (case-sensitive) |
| `ilike` | | SQL ILIKE (case-insensitive) |
| `in` | | Value in list |
| `is_null` | | Column IS NULL |
| `is_not_null` | | Column IS NOT NULL |
### Logic Operators
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
---
## Sorting
```json
[
{ "column": "created_at", "direction": "desc" },
{ "column": "name", "direction": "asc" }
]
```
---
## Pagination
### Offset-Based
```json
{ "limit": 20, "offset": 40 }
```
### Cursor-Based
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
```json
// Next page: pass the PK of the last record on the current page
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
// Previous page: pass the PK of the first record on the current page
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
```
---
## Preloading Relations
```json
[
{ "relation": "Profile" },
{ "relation": "Orders" }
]
```
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
---
## Hook System
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
### Hook Types
| Constant | Fires |
|---|---|
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
| `BeforeRead` / `AfterRead` | Around read queries |
| `BeforeCreate` / `AfterCreate` | Around insert |
| `BeforeUpdate` / `AfterUpdate` | Around update |
| `BeforeDelete` / `AfterDelete` | Around delete |
### Registering Hooks
```go
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
// Inject a timestamp before insert
if data, ok := ctx.Data.(map[string]interface{}); ok {
data["created_at"] = time.Now()
}
return nil
})
// Register the same hook for multiple events
handler.Hooks().RegisterMultiple(
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
auditHook,
)
```
### HookContext Fields
| Field | Type | Description |
|---|---|---|
| `Context` | `context.Context` | Request context |
| `Handler` | `*Handler` | The resolvemcp handler |
| `Schema` | `string` | Database schema name |
| `Entity` | `string` | Entity/table name |
| `Model` | `interface{}` | Registered model instance |
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
| `ID` | `string` | Primary key from request (read/update/delete) |
| `Data` | `interface{}` | Input data (create/update — modifiable) |
| `Result` | `interface{}` | Output data (set by After hooks) |
| `Error` | `error` | Operation error, if any |
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
| `Tx` | `common.Database` | Database/transaction handle |
| `Abort` | `bool` | Set to `true` to abort the operation |
| `AbortMessage` | `string` | Error message returned when aborting |
| `AbortCode` | `int` | Optional status code for the abort |
### Aborting an Operation
```go
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "deletion is disabled"
return nil
})
```
### Managing Hooks
```go
registry := handler.Hooks()
registry.HasHooks(resolvemcp.BeforeCreate) // bool
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
registry.ClearAll() // remove all hooks
```
---
## Context Helpers
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
```go
schema := resolvemcp.GetSchema(ctx)
entity := resolvemcp.GetEntity(ctx)
tableName := resolvemcp.GetTableName(ctx)
model := resolvemcp.GetModel(ctx)
modelPtr := resolvemcp.GetModelPtr(ctx)
```
You can also set values manually (e.g. in middleware):
```go
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
```
---
## Adding Custom MCP Tools
Access the underlying `*server.MCPServer` to register additional tools:
```go
mcpServer := handler.MCPServer()
mcpServer.AddTool(myTool, myHandler)
```
---
## Table Name Resolution
The handler resolves table names in priority order:
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)

View File

@@ -0,0 +1,107 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
)
const annotationToolName = "resolvespec_annotate"
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
// The tool lets models/entities store and retrieve freeform annotation records
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
func registerAnnotationTool(h *Handler) {
tool := mcp.NewTool(annotationToolName,
mcp.WithDescription(
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
"To set annotations: provide both 'tool_name' and 'annotations'. "+
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
"To get annotations: provide only 'tool_name'. "+
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
"a model/entity name (e.g. 'public.users'), or any other key.",
),
mcp.WithString("tool_name",
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
mcp.Required(),
),
mcp.WithObject("annotations",
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
toolName, ok := args["tool_name"].(string)
if !ok || toolName == "" {
return mcp.NewToolResultError("missing required argument: tool_name"), nil
}
annotations, hasAnnotations := args["annotations"]
if hasAnnotations && annotations != nil {
return executeSetAnnotation(ctx, h, toolName, annotations)
}
return executeGetAnnotation(ctx, h, toolName)
})
}
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
jsonBytes, err := json.Marshal(annotations)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
}
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "set",
})
}
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
var rows []map[string]interface{}
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
}
var annotations interface{}
if len(rows) > 0 {
// The procedure returns a single value; extract the first column of the first row.
for _, v := range rows[0] {
annotations = v
break
}
}
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
switch v := annotations.(type) {
case []byte:
var decoded interface{}
if json.Unmarshal(v, &decoded) == nil {
annotations = decoded
}
case string:
var decoded interface{}
if json.Unmarshal([]byte(v), &decoded) == nil {
annotations = decoded
}
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "get",
"annotations": annotations,
})
}

71
pkg/resolvemcp/context.go Normal file
View File

@@ -0,0 +1,71 @@
package resolvemcp
import "context"
type contextKey string
const (
contextKeySchema contextKey = "schema"
contextKeyEntity contextKey = "entity"
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
)
func WithSchema(ctx context.Context, schema string) context.Context {
return context.WithValue(ctx, contextKeySchema, schema)
}
func GetSchema(ctx context.Context) string {
if v := ctx.Value(contextKeySchema); v != nil {
return v.(string)
}
return ""
}
func WithEntity(ctx context.Context, entity string) context.Context {
return context.WithValue(ctx, contextKeyEntity, entity)
}
func GetEntity(ctx context.Context) string {
if v := ctx.Value(contextKeyEntity); v != nil {
return v.(string)
}
return ""
}
func WithTableName(ctx context.Context, tableName string) context.Context {
return context.WithValue(ctx, contextKeyTableName, tableName)
}
func GetTableName(ctx context.Context) string {
if v := ctx.Value(contextKeyTableName); v != nil {
return v.(string)
}
return ""
}
func WithModel(ctx context.Context, model interface{}) context.Context {
return context.WithValue(ctx, contextKeyModel, model)
}
func GetModel(ctx context.Context) interface{} {
return ctx.Value(contextKeyModel)
}
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
}
func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
return ctx
}

161
pkg/resolvemcp/cursor.go Normal file
View File

@@ -0,0 +1,161 @@
package resolvemcp
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
type cursorDirection int
const (
cursorForward cursorDirection = 1
cursorBackward cursorDirection = -1
)
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
func getCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
cursorID, direction := getActiveCursor(options)
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
sortItems := options.Sort
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
var whereClauses []string
joinSQL := ""
reverse := direction < 0
for _, s := range sortItems {
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
desc := strings.EqualFold(s.Direction, "desc")
if reverse {
desc = !desc
}
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
orSQL := buildCursorPriorityChain(whereClauses)
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
)
return query, nil
}
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
if options.CursorForward != "" {
return options.CursorForward, cursorForward
}
if options.CursorBackward != "" {
return options.CursorBackward, cursorBackward
}
return "", 0
}
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if prefix != "" && prefix != tableName {
return "", "", true, nil
}
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
func buildCursorPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

761
pkg/resolvemcp/handler.go Normal file
View File

@@ -0,0 +1,761 @@
package resolvemcp
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"github.com/mark3labs/mcp-go/server"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// Handler exposes registered database models as MCP tools and resources.
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
mcpServer *server.MCPServer
config Config
name string
version string
oauth2Regs []oauth2Registration
oauthSrv *security.OAuthServer
}
// NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
h := &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
config: cfg,
name: "resolvemcp",
version: "1.0.0",
}
registerAnnotationTool(h)
return h
}
// Hooks returns the hook registry.
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// GetDatabase returns the underlying database.
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
func (h *Handler) MCPServer() *server.MCPServer {
return h.mcpServer
}
// SSEServer returns an http.Handler that serves MCP over SSE.
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
// detected automatically from each incoming request.
func (h *Handler) SSEServer() http.Handler {
if h.config.BaseURL != "" {
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
}
return &dynamicSSEHandler{h: h}
}
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
// client-server communication (POST for requests, GET for server-initiated messages).
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
func (h *Handler) StreamableHTTPServer() http.Handler {
return server.NewStreamableHTTPServer(h.mcpServer)
}
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer(
h.mcpServer,
server.WithBaseURL(baseURL),
server.WithStaticBasePath(basePath),
)
}
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
type dynamicSSEHandler struct {
h *Handler
mu sync.Mutex
pool map[string]*server.SSEServer
}
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
baseURL := requestBaseURL(r)
d.mu.Lock()
if d.pool == nil {
d.pool = make(map[string]*server.SSEServer)
}
s, ok := d.pool[baseURL]
if !ok {
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
d.pool[baseURL] = s
}
d.mu.Unlock()
s.ServeHTTP(w, r)
}
// requestBaseURL builds the base URL from an incoming request.
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
func requestBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = proto
}
return scheme + "://" + r.Host
}
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
fullName := buildModelName(schema, entity)
if err := h.registry.RegisterModel(fullName, model); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// RegisterModelWithRules registers a model and sets per-entity operation rules
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
fullName := buildModelName(schema, entity)
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// SetModelRules updates the operation rules for an already-registered model.
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
return reg.SetModelRules(buildModelName(schema, entity), rules)
}
// buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string {
if schema == "" {
return entity
}
return fmt.Sprintf("%s.%s", schema, entity)
}
// getTableName returns the fully qualified table name for a model.
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
if idx := strings.LastIndex(tableName, "."); idx != -1 {
return tableName[:idx], tableName[idx+1:]
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), tableName
}
return defaultSchema, tableName
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), entity
}
return defaultSchema, entity
}
// recoverPanic catches a panic from the current goroutine and returns it as an error.
// Usage: defer recoverPanic(&returnedErr)
func recoverPanic(err *error) {
if r := recover(); r != nil {
msg := fmt.Sprintf("%v", r)
logger.Error("[resolvemcp] panic recovered: %s", msg)
*err = fmt.Errorf("internal error: %s", msg)
}
}
// executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err)
}
unwrapped, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, nil, fmt.Errorf("invalid model: %w", err)
}
model = unwrapped.Model
modelType := unwrapped.ModelType
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
validator := common.NewColumnValidator(model)
options = validator.FilterRequestOptions(options)
// BeforeHandle hook
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "read",
Options: options,
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, nil, err
}
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
query := h.db.NewSelect().Model(modelPtr)
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Column selection
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
options.Columns = reflection.GetSQLModelColumns(model)
}
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
for _, cu := range options.ComputedColumns {
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
}
// Filters
query = h.applyFilters(query, options.Filters)
// Custom operators
for _, customOp := range options.CustomOperators {
query = query.Where(customOp.SQL)
}
// Sorting
for _, sort := range options.Sort {
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Cursor pagination
if options.CursorForward != "" || options.CursorBackward != "" {
pkName := reflection.GetPrimaryKeyName(model)
modelColumns := reflection.GetModelColumns(model)
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// expandJoins is empty for resolvemcp — no custom SQL join support yet
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err)
}
if cursorFilter != "" {
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
sanitized = common.EnsureOuterParentheses(sanitized)
if sanitized != "" {
query = query.Where(sanitized)
}
}
}
// Count — must happen before preloads are applied; Bun panics when counting with relations.
total, err := query.Count(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err)
}
// Pagination
if options.Limit != nil && *options.Limit > 0 {
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
query = query.Offset(*options.Offset)
}
// Preloads — applied after count to avoid Bun panic when counting with relations.
if len(options.Preload) > 0 {
var preloadErr error
query, preloadErr = h.applyPreloads(model, query, options.Preload)
if preloadErr != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
}
}
// BeforeRead hook
hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
return nil, nil, err
}
var data interface{}
if id != "" {
singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := query.Scan(ctx, singleResult); err != nil {
if err == sql.ErrNoRows {
return nil, nil, fmt.Errorf("record not found")
}
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = singleResult
} else {
if err := query.Scan(ctx, modelPtr); err != nil {
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = reflect.ValueOf(modelPtr).Elem().Interface()
}
limit := 0
offset := 0
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Count is the number of records in this page, not the total.
var pageCount int64
if id != "" {
pageCount = 1
} else {
pageCount = int64(reflect.ValueOf(data).Len())
}
metadata := &common.Metadata{
Total: int64(total),
Filtered: int64(total),
Count: pageCount,
Limit: limit,
Offset: offset,
}
// AfterRead hook
hookCtx.Result = data
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
return nil, nil, err
}
return data, metadata, nil
}
// executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "create",
Data: data,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
return nil, err
}
// Use potentially modified data
data = hookCtx.Data
switch v := data.(type) {
case map[string]interface{}:
query := h.db.NewInsert().Table(tableName)
for key, value := range v {
query = query.Value(key, value)
}
if _, err := query.Exec(ctx); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
hookCtx.Result = v
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return v, nil
case []interface{}:
results := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
itemMap, ok := item.(map[string]interface{})
if !ok {
return fmt.Errorf("each item must be an object")
}
q := tx.NewInsert().Table(tableName)
for key, value := range itemMap {
q = q.Value(key, value)
}
if _, err := q.Exec(ctx); err != nil {
return err
}
results = append(results, item)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("batch create error: %w", err)
}
hookCtx.Result = results
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return results, nil
default:
return nil, fmt.Errorf("data must be an object or array of objects")
}
}
// executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
updates, ok := data.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("data must be an object")
}
if id == "" {
if idVal, exists := updates["id"]; exists {
id = fmt.Sprintf("%v", idVal)
}
}
if id == "" {
return nil, fmt.Errorf("update requires an ID")
}
pkName := reflection.GetPrimaryKeyName(model)
var updateResult interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Read existing record
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
existingRecord := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no records found to update")
}
return fmt.Errorf("error fetching existing record: %w", err)
}
// Convert to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
return fmt.Errorf("error marshaling existing record: %w", err)
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
return fmt.Errorf("error unmarshaling existing record: %w", err)
}
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "update",
ID: id,
Data: updates,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return err
}
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
updates = modifiedData
}
// Merge non-nil, non-empty values
for key, newValue := range updates {
if newValue == nil {
continue
}
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
}
existingMap[key] = newValue
}
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
res, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("error updating record: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("no records found to update")
}
updateResult = existingMap
hookCtx.Result = updateResult
return h.hooks.Execute(AfterUpdate, hookCtx)
})
if err != nil {
return nil, err
}
return updateResult, nil
}
// executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
if id == "" {
return nil, fmt.Errorf("delete requires an ID")
}
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
pkName := reflection.GetPrimaryKeyName(model)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "delete",
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
return nil, err
}
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
var recordToDelete interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
record := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(record).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("record not found")
}
return fmt.Errorf("error fetching record: %w", err)
}
res, err := tx.NewDelete().Table(tableName).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete error: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("record not found or already deleted")
}
recordToDelete = record
hookCtx.Tx = tx
hookCtx.Result = record
return h.hooks.Execute(AfterDelete, hookCtx)
})
if err != nil {
return nil, err
}
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
return recordToDelete, nil
}
// applyFilters applies all filters with OR grouping logic.
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
}
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
switch filter.Operator {
case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
case "neq", "!=", "<>":
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
case "gt", ">":
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
case "gte", ">=":
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
case "lt", "<":
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
case "lte", "<=":
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
case "in":
condition, args := common.BuildInCondition(filter.Column, filter.Value)
return condition, args
case "is_null":
return fmt.Sprintf("%s IS NULL", filter.Column), nil
case "is_not_null":
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
}
return "", nil
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for i := range preloads {
preload := &preloads[i]
if preload.Relation == "" {
continue
}
query = query.PreloadRelation(preload.Relation)
}
return query, nil
}

113
pkg/resolvemcp/hooks.go Normal file
View File

@@ -0,0 +1,113 @@
package resolvemcp
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
BeforeHandle HookType = "before_handle"
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Operation string
ID string
Data interface{}
Result interface{}
Error error
Query common.SelectQuery
Abort bool
AbortMessage string
AbortCode int
Tx common.Database
}
// HookFunc is the signature for hook functions
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
if ctx.Abort {
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
}
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
}
func (r *HookRegistry) HasHooks(hookType HookType) bool {
hooks, exists := r.hooks[hookType]
return exists && len(hooks) > 0
}

264
pkg/resolvemcp/oauth2.go Normal file
View File

@@ -0,0 +1,264 @@
package resolvemcp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// --------------------------------------------------------------------------
// OAuth2 registration on the Handler
// --------------------------------------------------------------------------
// oauth2Registration stores a configured auth provider and its route config.
type oauth2Registration struct {
auth *security.DatabaseAuthenticator
cfg OAuth2RouteConfig
}
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
// Call this once per provider before serving requests.
//
// Example:
//
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
// ProviderName: "google",
// LoginPath: "/auth/google/login",
// CallbackPath: "/auth/google/callback",
// AfterLoginRedirect: "/",
// })
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
}
// HTTPHandler returns a single http.Handler that serves:
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
// - The MCP SSE transport wrapped with required authentication middleware
//
// Example:
//
// auth := security.NewGoogleAuthenticator(...)
// handler.RegisterOAuth2(auth, cfg)
// handler.EnableOAuthServer(security.OAuthServerConfig{Issuer: "https://api.example.com"})
// security.RegisterSecurityHooks(handler, securityList)
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
mux := http.NewServeMux()
if h.oauthSrv != nil {
h.mountOAuthServerRoutes(mux)
}
h.mountOAuth2Routes(mux)
mcpHandler := h.AuthedSSEServer(securityList)
basePath := h.config.BasePath
if basePath == "" {
basePath = "/mcp"
}
mux.Handle(basePath+"/sse", mcpHandler)
mux.Handle(basePath+"/message", mcpHandler)
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
return mux
}
// StreamableHTTPMux returns a single http.Handler that serves:
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
// - The MCP streamable HTTP transport wrapped with required authentication middleware
//
// Example:
//
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
mux := http.NewServeMux()
if h.oauthSrv != nil {
h.mountOAuthServerRoutes(mux)
}
h.mountOAuth2Routes(mux)
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
basePath := h.config.BasePath
if basePath == "" {
basePath = "/mcp"
}
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
mux.Handle(basePath, mcpHandler)
return mux
}
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
for _, reg := range h.oauth2Regs {
var cookieOpts []security.SessionCookieOptions
if reg.cfg.CookieOptions != nil {
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
}
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
}
}
// --------------------------------------------------------------------------
// Auth-wrapped transports
// --------------------------------------------------------------------------
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
// The middleware reads the session cookie / Authorization header and populates the user
// context into the request context, making it available to BeforeHandle security hooks.
// Unauthenticated requests receive 401 before reaching any MCP tool.
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
return security.NewAuthMiddleware(securityList)(h.SSEServer())
}
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
// Unauthenticated requests continue as guest rather than returning 401.
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
// to allow mixed public/private access.
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
}
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
}
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
}
// --------------------------------------------------------------------------
// OAuth2 route config and standalone handlers
// --------------------------------------------------------------------------
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
type OAuth2RouteConfig struct {
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
// (e.g. "google", "github", "microsoft").
ProviderName string
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
// (e.g. "/auth/google/login").
LoginPath string
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
CallbackPath string
// AfterLoginRedirect is the URL to redirect the browser to after a successful
// login. When empty the LoginResponse JSON is written directly to the response.
AfterLoginRedirect string
// CookieOptions customises the session cookie written on successful login.
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
CookieOptions *security.SessionCookieOptions
}
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
// the OAuth2 provider's authorization URL.
//
// Register it on any router:
//
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := auth.OAuth2GenerateState()
if err != nil {
http.Error(w, "failed to generate state", http.StatusInternalServerError)
return
}
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
}
}
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
// callback: exchanges the authorization code for a session token, writes the session
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
//
// Register it on any router:
//
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code parameter", http.StatusBadRequest)
return
}
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
security.SetSessionCookie(w, loginResp, cookieOpts...)
if afterLoginRedirect != "" {
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
}
}
// --------------------------------------------------------------------------
// Gorilla Mux convenience helpers
// --------------------------------------------------------------------------
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
//
// Example:
//
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
// ProviderName: "google", LoginPath: "/auth/google/login",
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
// })
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
var cookieOpts []security.SessionCookieOptions
if cfg.CookieOptions != nil {
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
}
muxRouter.Handle(cfg.LoginPath,
OAuth2LoginHandler(auth, cfg.ProviderName),
).Methods(http.MethodGet)
muxRouter.Handle(cfg.CallbackPath,
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
).Methods(http.MethodGet)
}
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
// with required authentication middleware applied.
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
basePath := handler.config.BasePath
h := handler.AuthedSSEServer(securityList)
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
// Gorilla Mux router with required authentication middleware applied.
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
basePath := handler.config.BasePath
h := handler.AuthedStreamableHTTPServer(securityList)
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}

View File

@@ -0,0 +1,51 @@
package resolvemcp
import (
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
//
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
// only external providers registered via RegisterOAuth2Provider.
//
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
// endpoints required by MCP clients alongside the MCP transport:
//
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
// POST /oauth/register RFC 7591 — dynamic client registration
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
// POST /oauth/authorize Login form submission (password flow)
// POST /oauth/token Bearer token exchange + refresh
// GET /oauth/provider/callback External provider redirect target
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
h.oauthSrv = security.NewOAuthServer(cfg, auth)
// Wire any external providers already registered via RegisterOAuth2
for _, reg := range h.oauth2Regs {
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
}
}
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
// EnableOAuthServer must be called before this. The auth must have been configured with
// WithOAuth2(providerName, ...) for the given provider name.
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
if h.oauthSrv != nil {
h.oauthSrv.RegisterExternalProvider(auth, providerName)
}
}
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
oauthHandler := h.oauthSrv.HTTPHandler()
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
mux.Handle("/.well-known/", oauthHandler)
mux.Handle("/oauth/", oauthHandler)
if h.oauthSrv != nil {
// Also mount the external provider callback path if it differs from /oauth/
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
}
}

View File

@@ -0,0 +1,133 @@
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
// and resources over HTTP/SSE transport.
//
// It mirrors the resolvespec package patterns:
// - Same model registration API
// - Same filter, sort, cursor pagination, preload options
// - Same lifecycle hook system
//
// Usage:
//
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
// handler.RegisterModel("public", "users", &User{})
//
// r := mux.NewRouter()
// resolvemcp.SetupMuxRoutes(r, handler)
package resolvemcp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
bunrouter "github.com/uptrace/bunrouter"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// Config holds configuration for the resolvemcp handler.
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
BaseURL string
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
// If empty, the path is detected from each incoming request automatically.
BasePath string
}
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
}
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
//
// To protect these routes with authentication, wrap the mux router or apply middleware
// before calling SetupMuxRoutes.
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
// Convenience: also expose the full SSE server at basePath for clients that
// use ServeHTTP directly (e.g. net/http default mux).
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
// using the base path from Config.BasePath.
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint
// - POST {basePath}/message — JSON-RPC message endpoint
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
}
// NewSSEServer returns an http.Handler that serves MCP over SSE.
// If Config.BasePath is set it is used directly; otherwise the base path is
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
//
// h := resolvemcp.NewSSEServer(handler)
// http.Handle("/api/mcp/", h)
func NewSSEServer(handler *Handler) http.Handler {
return handler.SSEServer()
}
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
// POST for client→server messages, GET for server→client streaming.
//
// Example:
//
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.StreamableHTTPServer()
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.StreamableHTTPServer()
router.GET(basePath, bunrouter.HTTPHandler(h))
router.POST(basePath, bunrouter.HTTPHandler(h))
router.DELETE(basePath, bunrouter.HTTPHandler(h))
}
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
// Mount it at the desired path; that path becomes the MCP endpoint.
//
// h := resolvemcp.NewStreamableHTTPHandler(handler)
// http.Handle("/mcp", h)
// engine.Any("/mcp", gin.WrapH(h))
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
return handler.StreamableHTTPServer()
}

View File

@@ -0,0 +1,115 @@
package resolvemcp
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks wires the security package's access-control layer into the
// resolvemcp handler. Call it once after creating the handler, before registering models.
//
// The following controls are applied:
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
// stored via RegisterModelWithRules / SetModelRules.
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
// - Column-level security: sensitive columns masked/hidden in read results.
// - Audit logging after each read.
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// BeforeHandle: enforce model-level operation rules (auth check).
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
})
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (2nd): audit log.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.LogDataAccess(newSecurityContext(hookCtx))
})
// BeforeUpdate: enforce CanUpdate rule.
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
})
// BeforeDelete: enforce CanDelete rule.
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
})
logger.Info("Security hooks registered for resolvemcp handler")
}
// --------------------------------------------------------------------------
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
// --------------------------------------------------------------------------
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

692
pkg/resolvemcp/tools.go Normal file
View File

@@ -0,0 +1,692 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// toolName builds the MCP tool name for a given operation and model.
func toolName(operation, schema, entity string) string {
if schema == "" {
return fmt.Sprintf("%s_%s", operation, entity)
}
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
}
// registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
info := buildModelInfo(schema, entity, model)
registerReadTool(h, schema, entity, info)
registerCreateTool(h, schema, entity, info)
registerUpdateTool(h, schema, entity, info)
registerDeleteTool(h, schema, entity, info)
registerModelResource(h, schema, entity, info)
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
}
// --------------------------------------------------------------------------
// Model introspection
// --------------------------------------------------------------------------
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
type modelInfo struct {
fullName string // e.g. "public.users"
pkName string // e.g. "id"
columns []columnInfo
relationNames []string
schemaDoc string // formatted multi-line schema listing
}
type columnInfo struct {
jsonName string
sqlName string
goType string
sqlType string
isPrimary bool
isUnique bool
isFK bool
nullable bool
}
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
info := modelInfo{
fullName: buildModelName(schema, entity),
pkName: reflection.GetPrimaryKeyName(model),
}
// Unwrap to base struct type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return info
}
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
for _, d := range details {
// Derive the JSON name from the struct field
jsonName := fieldJSONName(modelType, d.Name)
if jsonName == "" || jsonName == "-" {
continue
}
// Skip relation fields (slice or user-defined struct that isn't time.Time).
fieldType, found := modelType.FieldByName(d.Name)
if found {
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
if ft.Kind() == reflect.Slice || isUserStruct {
info.relationNames = append(info.relationNames, jsonName)
continue
}
}
sqlName := d.SQLName
if sqlName == "" {
sqlName = jsonName
}
// Derive Go type name, unwrapping pointer if needed.
goType := d.DataType
if goType == "" && found {
ft := fieldType.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
goType = ft.Name()
}
// isPrimary: use both the GORM-tag detection and a name comparison against
// the known primary key (handles camelCase "primaryKey" tags correctly).
isPrimary := d.SQLKey == "primary_key" ||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
ci := columnInfo{
jsonName: jsonName,
sqlName: sqlName,
goType: goType,
sqlType: d.SQLDataType,
isPrimary: isPrimary,
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
isFK: d.SQLKey == "foreign_key",
nullable: d.Nullable,
}
info.columns = append(info.columns, ci)
}
info.schemaDoc = buildSchemaDoc(info)
return info
}
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
func fieldJSONName(modelType reflect.Type, fieldName string) string {
field, ok := modelType.FieldByName(fieldName)
if !ok {
return fieldName
}
tag := field.Tag.Get("json")
if tag == "" {
return fieldName
}
parts := strings.SplitN(tag, ",", 2)
if parts[0] == "" {
return fieldName
}
return parts[0]
}
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
func buildSchemaDoc(info modelInfo) string {
if len(info.columns) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("Columns:\n")
for _, c := range info.columns {
line := fmt.Sprintf(" • %s", c.jsonName)
typeDesc := c.goType
if c.sqlType != "" {
typeDesc = c.sqlType
}
if typeDesc != "" {
line += fmt.Sprintf(" (%s)", typeDesc)
}
var flags []string
if c.isPrimary {
flags = append(flags, "primary key")
}
if c.isUnique {
flags = append(flags, "unique")
}
if c.isFK {
flags = append(flags, "foreign key")
}
if !c.nullable && !c.isPrimary {
flags = append(flags, "not null")
} else if c.nullable {
flags = append(flags, "nullable")
}
if len(flags) > 0 {
line += " — " + strings.Join(flags, ", ")
}
sb.WriteString(line + "\n")
}
if len(info.relationNames) > 0 {
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
}
return sb.String()
}
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
func columnNameList(cols []columnInfo) string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.jsonName
}
return strings.Join(names, ", ")
}
// writableColumnNames returns JSON names for all non-primary-key columns.
func writableColumnNames(cols []columnInfo) []string {
var names []string
for _, c := range cols {
if !c.isPrimary {
names = append(names, c.jsonName)
}
}
return names
}
// --------------------------------------------------------------------------
// Read tool
// --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("read", schema, entity)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
}
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
descParts = append(descParts,
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
)
if len(info.relationNames) > 0 {
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
}
description := strings.Join(descParts, "\n\n")
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
if len(info.columns) > 0 {
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
if len(info.columns) > 0 {
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
),
mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return per page. Recommended: 10100."),
),
mcp.WithNumber("offset",
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
),
mcp.WithString("cursor_forward",
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithString("cursor_backward",
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithArray("columns",
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("omit_columns",
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("filters",
mcp.Description(filterDesc),
),
mcp.WithArray("sort",
mcp.Description(sortDesc),
),
mcp.WithArray("preloads",
mcp.Description(buildPreloadDesc(info)),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
options := parseRequestOptions(args)
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": data,
"metadata": metadata,
})
})
}
func buildPreloadDesc(info modelInfo) string {
if len(info.relationNames) == 0 {
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
}
return fmt.Sprintf(
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
strings.Join(info.relationNames, ", "),
)
}
// --------------------------------------------------------------------------
// Create tool
// --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("create", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
}
descParts = append(descParts,
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
dataDesc := "Record fields to create."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
}
dataDesc += " Pass a single object or an array of objects."
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
result, err := h.executeCreate(ctx, schema, entity, data)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Update tool
// --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("update", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
}
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
}
descParts = append(descParts,
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(idDesc),
),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
dataMap, ok := data.(map[string]interface{})
if !ok {
return mcp.NewToolResultError("data must be an object"), nil
}
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Delete tool
// --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("delete", schema, entity)
descParts := []string{
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
}
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
description := strings.Join(descParts, " ")
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
result, err := h.executeDelete(ctx, schema, entity, id)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Resource registration
// --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
resourceURI := info.fullName
var resourceDesc strings.Builder
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
if info.pkName != "" {
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
}
if info.schemaDoc != "" {
resourceDesc.WriteString("\n\n")
resourceDesc.WriteString(info.schemaDoc)
}
resource := mcp.NewResource(
resourceURI,
entity,
mcp.WithResourceDescription(resourceDesc.String()),
mcp.WithMIMEType("application/json"),
)
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
limit := 100
options := common.RequestOptions{Limit: &limit}
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
if err != nil {
return nil, err
}
payload := map[string]interface{}{
"data": data,
"metadata": metadata,
}
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling resource: %w", err)
}
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "application/json",
Text: string(jsonBytes),
},
}, nil
})
}
// --------------------------------------------------------------------------
// Argument parsing helpers
// --------------------------------------------------------------------------
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{}
if v, ok := args["limit"]; ok {
switch n := v.(type) {
case float64:
limit := int(n)
options.Limit = &limit
case int:
options.Limit = &n
}
}
if v, ok := args["offset"]; ok {
switch n := v.(type) {
case float64:
offset := int(n)
options.Offset = &offset
case int:
options.Offset = &n
}
}
if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v
}
if v, ok := args["cursor_backward"].(string); ok {
options.CursorBackward = v
}
options.Columns = parseStringArray(args["columns"])
options.OmitColumns = parseStringArray(args["omit_columns"])
options.Filters = parseFilters(args["filters"])
options.Sort = parseSortOptions(args["sort"])
options.Preload = parsePreloadOptions(args["preloads"])
return options
}
func parseStringArray(raw interface{}) []string {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]string, 0, len(items))
for _, item := range items {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}
func parseFilters(raw interface{}) []common.FilterOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.FilterOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var f common.FilterOption
if err := json.Unmarshal(b, &f); err != nil {
continue
}
if f.Column == "" || f.Operator == "" {
continue
}
if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR"
} else {
f.LogicOperator = "AND"
}
result = append(result, f)
}
return result
}
func parseSortOptions(raw interface{}) []common.SortOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.SortOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var s common.SortOption
if err := json.Unmarshal(b, &s); err != nil {
continue
}
if s.Column == "" {
continue
}
result = append(result, s)
}
return result
}
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.PreloadOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var p common.PreloadOption
if err := json.Unmarshal(b, &p); err != nil {
continue
}
if p.Relation == "" {
continue
}
result = append(result, p)
}
return result
}
// marshalResult marshals a value to JSON and returns it as an MCP text result.
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
b, err := json.Marshal(v)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
}
return mcp.NewToolResultText(string(b)), nil
}

View File

@@ -24,6 +24,7 @@ const (
// - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - options: the request options containing sort and cursor information
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
//
// Returns SQL snippet to embed in WHERE clause.
func GetCursorFilter(
@@ -31,6 +32,7 @@ func GetCursorFilter(
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
// Separate schema prefix from bare table name
fullTableName := tableName
@@ -58,18 +60,19 @@ func GetCursorFilter(
// 3. Prepare
// --------------------------------------------------------------------- //
var whereClauses []string
joinSQL := ""
reverse := direction < 0
// --------------------------------------------------------------------- //
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
// Parse: "created_at", "user.name", etc.
// Parse: "created_at", "user.name", "fn.sortorder", etc.
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
@@ -82,7 +85,7 @@ func GetCursorFilter(
}
// Resolve column
cursorCol, targetCol, err := resolveColumn(
cursorCol, targetCol, isJoin, err := resolveColumn(
field, prefix, tableName, modelColumns,
)
if err != nil {
@@ -90,6 +93,22 @@ func GetCursorFilter(
continue
}
// Handle joins
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
// Build inequality
op := "<"
if desc {
@@ -113,10 +132,12 @@ func GetCursorFilter(
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
@@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
return "", 0
}
// Helper: resolve column (main table only for now)
// Helper: resolve column (main table or join)
func resolveColumn(
field, prefix, tableName string,
modelColumns []string,
) (cursorCol, targetCol string, err error) {
) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Main table column
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
// No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Joined column (not supported in resolvespec yet)
// Joined column
if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
return "", "", true, nil
}
return "", "", fmt.Errorf("invalid column: %s", field)
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
// ------------------------------------------------------------------------- //

View File

@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil {
t.Error("Expected error when no cursor is provided")
}
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil {
t.Error("Expected error when no sort columns are defined")
}
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "name", "email"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Logf("Generated cursor filter with schema: %s", filter)
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
options := common.RequestOptions{
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
CursorForward: "8975",
}
tableName := "core.account"
pkName := "rid_account"
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestGetActiveCursor(t *testing.T) {
tests := []struct {
name string
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
wantErr: false,
},
{
name: "Joined column (not supported)",
name: "Joined column (isJoin=true, no error)",
field: "name",
prefix: "user",
tableName: "posts",
modelColumns: []string{"id", "title"},
wantErr: true,
wantErr: false,
// cursorCol and targetCol are empty when isJoin=true; handled by caller
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
if tt.wantErr {
if err == nil {
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
// For join columns, cursor/target are empty and isJoin=true
if isJoin {
if cursor != "" || target != "" {
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
}
return
}
if cursor != tt.wantCursor {
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
}
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}

View File

@@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)

View File

@@ -64,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
@@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
}
// Handle joins
if isJoin && expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}

View File

@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
}
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "fn.sortorder", Direction: "ASC"},
},
},
}
opts.CursorForward = "8975"
tableName := "core.account"
pkName := "rid_account"
// modelColumns does not contain "sortorder" - it's a lateral join computed column
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
// Should contain the rewritten lateral join inside the EXISTS subquery
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
// Should compare fn.sortorder values
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
// Should NOT contain empty comparison like "< "
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestBuildPriorityChain(t *testing.T) {
clauses := []string{
"cursor_select.priority > posts.priority",

View File

@@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation using the generic database function
modelColumns := reflection.GetModelColumns(model)
// Build expand joins map (if needed in future)
var expandJoins map[string]string
if len(options.Expand) > 0 {
expandJoins = make(map[string]string)
// TODO: Build actual JOIN SQL for each expand relation
// For now, pass empty map as joins are handled via Preload
// Build expand joins map: custom SQL joins are available in cursor subquery
expandJoins := make(map[string]string)
for _, joinClause := range options.CustomSQLJoin {
alias := extractJoinAlias(joinClause)
if alias != "" {
expandJoins[alias] = joinClause
}
}
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
// Default sort to primary key when none provided
if len(options.Sort) == 0 {

View File

@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
}
// Resolve relation names (convert table names to field names) if model is provided
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
if model != nil && !options.XFilesPresent {
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
// disambiguate when multiple fields point to the same related type.
if model != nil {
h.resolveRelationNamesInOptions(&options, model)
}
@@ -550,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
// - "LEFT JOIN departments d ON ..." -> "d"
// - "INNER JOIN users AS u ON ..." -> "u"
// - "JOIN roles r ON ..." -> "r"
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
func extractJoinAlias(joinClause string) string {
// Pattern: JOIN table_name [AS] alias ON ...
// We need to extract the alias (word before ON)
upperJoin := strings.ToUpper(joinClause)
// Find the "JOIN" keyword position
@@ -562,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
return ""
}
// Find the "ON" keyword position
// Lateral joins: alias is the word after the closing ) and before ON
if strings.Contains(upperJoin, "LATERAL") {
lastClose := strings.LastIndex(joinClause, ")")
if lastClose != -1 {
words := strings.Fields(joinClause[lastClose+1:])
// words should be like ["fn", "on", "true"] or ["on", "true"]
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
return words[0]
}
}
return ""
}
// Regular joins: find the "ON" keyword position (first occurrence)
onIdx := strings.Index(upperJoin, " ON ")
if onIdx == -1 {
return ""
@@ -863,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Resolve each part of the path
currentModel := model
for _, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part)
for partIdx, part := range parts {
isLast := partIdx == len(parts)-1
var resolvedPart string
if isLast {
// For the final part, use join-key-aware resolution to disambiguate when
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
localKey := preload.RelatedKey
if localKey == "" {
localKey = preload.ForeignKey
}
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
} else {
resolvedPart = h.resolveRelationName(currentModel, part)
}
resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level
@@ -980,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
return nameOrTable
}
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
if localKey == "" {
return h.resolveRelationName(model, nameOrTable)
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nameOrTable
}
// If it's already a direct field name, return as-is (no ambiguity).
for i := 0; i < modelType.NumField(); i++ {
if modelType.Field(i).Name == nameOrTable {
return nameOrTable
}
}
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
localKeyLower := strings.ToLower(localKey)
// Find all fields whose related type matches nameOrTable, then pick the one
// whose bun join tag local key matches localKey.
var fallbackField string
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil && targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType == nil || targetType.Kind() != reflect.Struct {
continue
}
normalizedTypeName := strings.ToLower(targetType.Name())
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
if normalizedTypeName != normalizedInput {
continue
}
// Type name matches; record as fallback.
if fallbackField == "" {
fallbackField = field.Name
}
// Check bun join tag: "join:localKey=foreignKey"
bunTag := field.Tag.Get("bun")
for _, tagPart := range strings.Split(bunTag, ",") {
tagPart = strings.TrimSpace(tagPart)
if !strings.HasPrefix(tagPart, "join:") {
continue
}
joinSpec := strings.TrimPrefix(tagPart, "join:")
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
joinCols := strings.Fields(joinSpec)
if len(joinCols) == 0 {
joinCols = []string{joinSpec}
}
for _, joinCol := range joinCols {
eqIdx := strings.Index(joinCol, "=")
if eqIdx < 0 {
continue
}
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
if joinLocalKey == localKeyLower {
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
return field.Name
}
}
}
}
if fallbackField != "" {
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
return fallbackField
}
return h.resolveRelationName(model, nameOrTable)
}
// addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {

View File

@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
joinClause: "LEFT JOIN departments",
expected: "",
},
{
name: "LATERAL join with alias",
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
expected: "fn",
},
{
name: "LATERAL join with multiline subquery containing inner ON",
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
expected: "fn",
},
}
for _, tt := range tests {

153
pkg/security/KEYSTORE.md Normal file
View File

@@ -0,0 +1,153 @@
# Keystore
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
## Key types
| Constant | Value | Use case |
|---|---|---|
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
## Storage backends
### ConfigKeyStore
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
```go
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
store := security.NewConfigKeyStore([]security.UserKey{
{
UserID: 1,
KeyType: security.KeyTypeGenericAPI,
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
Name: "CI deploy",
Scopes: []string{"deploy"},
IsActive: true,
},
})
```
### DatabaseKeyStore
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
```go
db, _ := sql.Open("postgres", dsn)
store := security.NewDatabaseKeyStore(db)
// With options
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
CacheTTL: 5 * time.Minute,
SQLNames: &security.KeyStoreSQLNames{
ValidateKey: "myapp_keystore_validate", // override one procedure name
},
})
```
## Managing keys
```go
ctx := context.Background()
// Create — raw key returned once; store it securely
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
UserID: 42,
KeyType: security.KeyTypeGenericAPI,
Name: "mobile app",
Scopes: []string{"read", "write"},
})
fmt.Println(resp.RawKey) // only shown here; hashed internally
// List
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
// Revoke
err = store.DeleteKey(ctx, 42, resp.Key.ID)
// Validate (used by authenticators internally)
key, err := store.ValidateKey(ctx, rawKey, "")
```
## HTTP authentication
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
Keys are extracted from the request in this order:
1. `Authorization: Bearer <key>`
2. `Authorization: ApiKey <key>`
3. `X-API-Key: <key>`
```go
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
// Restrict to a specific type:
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
```
Plug it into a handler:
```go
handler := resolvespec.NewHandler(db, registry,
resolvespec.WithAuthenticator(auth),
)
```
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
On successful validation the request context receives a `UserContext` where:
- `UserID` — from the key
- `Roles` — the key's `Scopes`
- `Claims["key_type"]` — key type string
- `Claims["key_name"]` — key name
## Database setup
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
```sql
\i pkg/security/keystore_schema.sql
```
This creates:
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
- `resolvespec_keystore_create_key(p_request jsonb)`
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
### Custom procedure names
```go
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
SQLNames: &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
CreateKey: "myschema_create_key",
DeleteKey: "myschema_delete_key",
ValidateKey: "myschema_validate_key",
},
})
// Validate names at startup
names := &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
// ...
}
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
log.Fatal(err)
}
```
## Security notes
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.

View File

@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
"token": req.Token,
"user_id": req.UserID,
}).Error
// Invalidate session via stored procedure
return nil
}
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {

View File

@@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
-**Testable** - Easy to mock and test
-**Extensible** - Implement custom providers for your needs
-**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
-**OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
## Stored Procedure Architecture
@@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
See `database_schema.sql` for complete stored procedure definitions and examples.
@@ -897,6 +904,156 @@ securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
```
## OAuth2 Authorization Server
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
### Endpoints
| Method | Path | RFC |
|--------|------|-----|
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
### Config
```go
cfg := security.OAuthServerConfig{
Issuer: "https://example.com", // Required — token issuer URL
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
LoginTitle: "My App Login", // HTML login page title
PersistClients: true, // Store clients in DB (multi-instance safe)
PersistCodes: true, // Store codes in DB (multi-instance safe)
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
AccessTokenTTL: time.Hour,
AuthCodeTTL: 5 * time.Minute,
}
```
| Field | Default | Notes |
|-------|---------|-------|
| `Issuer` | — | Required; trailing slash is trimmed automatically |
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
| `LoginTitle` | `"Sign in"` | |
| `PersistClients` | `false` | Set `true` for multi-instance |
| `PersistCodes` | `false` | Set `true` for multi-instance; does not require `PersistClients` |
| `DefaultScopes` | `["openid","profile","email"]` | |
| `AccessTokenTTL` | `24h` | Also used as `expires_in` in token responses |
| `AuthCodeTTL` | `2m` | |
### Operating Modes
**Mode 1 — Direct login (username/password form)**
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
```go
auth := security.NewDatabaseAuthenticator(db)
srv := security.NewOAuthServer(cfg, auth)
```
**Mode 2 — External provider federation**
Pass a `*DatabaseAuthenticator` for persistence (authorization codes, revoke, introspect) and register external providers. The authorize endpoint redirects to the specified provider (via the `provider` query param) or to the first registered provider by default.
```go
auth := security.NewDatabaseAuthenticator(db)
srv := security.NewOAuthServer(cfg, auth)
srv.RegisterExternalProvider(googleAuth, "google")
srv.RegisterExternalProvider(githubAuth, "github")
```
**Mode 3 — Both**
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
```go
srv := security.NewOAuthServer(cfg, auth)
srv.RegisterExternalProvider(googleAuth, "google")
```
### Standalone Usage
```go
mux := http.NewServeMux()
mux.Handle("/.well-known/", srv.HTTPHandler())
mux.Handle("/oauth/", srv.HTTPHandler())
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
http.ListenAndServe(":8080", mux)
```
### DB Persistence
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
#### New DB Types
```go
type OAuthServerClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
ClientState string `json:"client_state,omitempty"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
SessionToken string `json:"session_token"`
Scopes []string `json:"scopes,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
}
type OAuthTokenInfo struct {
Active bool `json:"active"`
Sub string `json:"sub,omitempty"`
Username string `json:"username,omitempty"`
Email string `json:"email,omitempty"`
Roles []string `json:"roles,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
}
```
#### DatabaseAuthenticator OAuth Methods
```go
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
auth.OAuthGetClient(ctx, clientID) // retrieve client
auth.OAuthSaveCode(ctx, code) // persist authorization code
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
```
#### SQLNames Fields
```go
type SQLNames struct {
// ... existing fields ...
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
OAuthGetClient string // default: "resolvespec_oauth_get_client"
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
OAuthRevoke string // default: "resolvespec_oauth_revoke"
}
```
The main changes:
1. Security package no longer knows about specific spec types
2. Each spec registers its own security hooks

View File

@@ -1397,3 +1397,180 @@ $$ LANGUAGE plpgsql;
-- Get credentials by username
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
-- ============================================
-- OAuth2 Server Tables (OAuthServer persistence)
-- ============================================
-- oauth_clients: persistent RFC 7591 registered clients
CREATE TABLE IF NOT EXISTS oauth_clients (
id SERIAL PRIMARY KEY,
client_id VARCHAR(255) NOT NULL UNIQUE,
redirect_uris TEXT[] NOT NULL,
client_name VARCHAR(255),
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
-- Note: client_id is stored without a foreign key so codes can be persisted even
-- when OAuth clients are managed in memory rather than persisted in oauth_clients.
CREATE TABLE IF NOT EXISTS oauth_codes (
id SERIAL PRIMARY KEY,
code VARCHAR(255) NOT NULL UNIQUE,
client_id VARCHAR(255) NOT NULL,
redirect_uri TEXT NOT NULL,
client_state TEXT,
code_challenge VARCHAR(255) NOT NULL,
code_challenge_method VARCHAR(10) DEFAULT 'S256',
session_token TEXT NOT NULL,
refresh_token TEXT,
scopes TEXT[],
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
-- ============================================
-- OAuth2 Server Stored Procedures
-- ============================================
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_client_id text;
v_row jsonb;
BEGIN
v_client_id := p_data->>'client_id';
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
VALUES (
v_client_id,
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
p_data->>'client_name',
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
)
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
RETURN QUERY SELECT true, null::text, v_row;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
SELECT to_jsonb(oauth_clients.*)
INTO v_row
FROM oauth_clients
WHERE client_id = p_client_id AND is_active = true;
IF v_row IS NULL THEN
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
RETURNS TABLE(p_success bool, p_error text)
LANGUAGE plpgsql AS $$
BEGIN
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, refresh_token, scopes, expires_at)
VALUES (
p_data->>'code',
p_data->>'client_id',
p_data->>'redirect_uri',
p_data->>'client_state',
p_data->>'code_challenge',
COALESCE(p_data->>'code_challenge_method', 'S256'),
p_data->>'session_token',
p_data->>'refresh_token',
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
(p_data->>'expires_at')::timestamp
);
RETURN QUERY SELECT true, null::text;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
DELETE FROM oauth_codes
WHERE code = p_code AND expires_at > now()
RETURNING jsonb_build_object(
'client_id', client_id,
'redirect_uri', redirect_uri,
'client_state', client_state,
'code_challenge', code_challenge,
'code_challenge_method', code_challenge_method,
'session_token', session_token,
'refresh_token', refresh_token,
'scopes', to_jsonb(scopes)
) INTO v_row;
IF v_row IS NULL THEN
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
SELECT jsonb_build_object(
'active', true,
'sub', u.id::text,
'username', u.username,
'email', u.email,
'user_level', u.user_level,
-- NULLIF converts empty string to NULL; string_to_array(NULL) returns NULL;
-- to_jsonb(NULL) returns NULL; COALESCE then returns '[]' for NULL/empty roles.
'roles', COALESCE(to_jsonb(string_to_array(NULLIF(u.roles, ''), ',')), '[]'::jsonb),
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
)
INTO v_row
FROM user_sessions s
JOIN users u ON u.id = s.user_id
WHERE s.session_token = p_token
AND s.expires_at > now()
AND u.is_active = true;
IF v_row IS NULL THEN
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
RETURNS TABLE(p_success bool, p_error text)
LANGUAGE plpgsql AS $$
BEGIN
DELETE FROM user_sessions WHERE session_token = p_token;
RETURN QUERY SELECT true, null::text;
END;
$$;

View File

@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
}
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil
}

81
pkg/security/keystore.go Normal file
View File

@@ -0,0 +1,81 @@
package security
import (
"context"
"crypto/sha256"
"encoding/hex"
"time"
)
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
// Used by all keystore implementations to hash raw keys before storage or lookup.
func hashSHA256Hex(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
// KeyType identifies the category of an auth key.
type KeyType string
const (
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
KeyTypeJWTSecret KeyType = "jwt_secret"
// KeyTypeHeaderAPI is a static API key sent via a request header.
KeyTypeHeaderAPI KeyType = "header_api"
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
KeyTypeOAuth2 KeyType = "oauth2"
// KeyTypeGenericAPI is a generic application API key.
KeyTypeGenericAPI KeyType = "api"
)
// UserKey represents a single named auth key belonging to a user.
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
type UserKey struct {
ID int64 `json:"id"`
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
IsActive bool `json:"is_active"`
}
// CreateKeyRequest specifies the parameters for a new key.
type CreateKeyRequest struct {
UserID int
KeyType KeyType
Name string
Scopes []string
Meta map[string]any
ExpiresAt *time.Time
}
// CreateKeyResponse is returned exactly once when a key is created.
// The caller is responsible for persisting RawKey; it is not stored anywhere.
type CreateKeyResponse struct {
Key UserKey
RawKey string // crypto/rand 32 bytes, base64url-encoded
}
// KeyStore manages per-user auth keys with pluggable storage backends.
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
type KeyStore interface {
// CreateKey generates a new key, stores its hash, and returns the raw key once.
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
// GetUserKeys returns all active, non-expired keys for a user.
// Pass an empty KeyType to return all types.
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
// DeleteKey soft-deletes a key by ID after verifying ownership.
DeleteKey(ctx context.Context, userID int, keyID int64) error
// ValidateKey checks a raw key, returns the matching UserKey on success.
// The implementation hashes the raw key before any lookup.
// Pass an empty KeyType to accept any type.
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
}

View File

@@ -0,0 +1,97 @@
package security
import (
"context"
"fmt"
"net/http"
"strings"
)
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
// is managed directly through the KeyStore.
//
// Key extraction order:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
type KeyStoreAuthenticator struct {
keyStore KeyStore
keyType KeyType // empty = accept any type
}
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
// Pass an empty keyType to accept keys of any type.
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
}
// Login is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
return nil, fmt.Errorf("keystore authenticator does not support login")
}
// Logout is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
return nil
}
// Authenticate extracts an API key from the request and validates it against the KeyStore.
// Returns a UserContext built from the matching UserKey on success.
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
rawKey := extractAPIKey(r)
if rawKey == "" {
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
}
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return userKeyToUserContext(userKey), nil
}
// extractAPIKey extracts a raw key from the request using the following precedence:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
func extractAPIKey(r *http.Request) string {
if auth := r.Header.Get("Authorization"); auth != "" {
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
return strings.TrimSpace(after)
}
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
return strings.TrimSpace(after)
}
}
return strings.TrimSpace(r.Header.Get("X-API-Key"))
}
// userKeyToUserContext converts a UserKey into a UserContext.
// Scopes are mapped to Roles. Key type and name are stored in Claims.
func userKeyToUserContext(k *UserKey) *UserContext {
claims := map[string]any{
"key_type": string(k.KeyType),
"key_name": k.Name,
}
meta := k.Meta
if meta == nil {
meta = map[string]any{}
}
roles := k.Scopes
if roles == nil {
roles = []string{}
}
return &UserContext{
UserID: k.UserID,
SessionID: fmt.Sprintf("key:%d", k.ID),
Roles: roles,
Claims: claims,
Meta: meta,
}
}

View File

@@ -0,0 +1,149 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"time"
)
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
//
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
type ConfigKeyStore struct {
mu sync.RWMutex
keys []UserKey
next int64 // monotonic ID counter for runtime-created keys (atomic)
}
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
// Pass nil or an empty slice to start with no pre-loaded keys.
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
var maxID int64
copied := make([]UserKey, len(keys))
copy(copied, keys)
for i := range copied {
if copied[i].CreatedAt.IsZero() {
copied[i].IsActive = true
copied[i].CreatedAt = time.Now()
}
if copied[i].ID > maxID {
maxID = copied[i].ID
}
}
return &ConfigKeyStore{keys: copied, next: maxID}
}
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
id := atomic.AddInt64(&s.next, 1)
key := UserKey{
ID: id,
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
CreatedAt: time.Now(),
IsActive: true,
}
s.mu.Lock()
s.keys = append(s.keys, key)
s.mu.Unlock()
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
now := time.Now()
s.mu.RLock()
defer s.mu.RUnlock()
var result []UserKey
for i := range s.keys {
k := &s.keys[i]
if k.UserID != userID || !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
result = append(result, *k)
}
return result, nil
}
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
if s.keys[i].ID == keyID {
if s.keys[i].UserID != userID {
return fmt.Errorf("key not found or permission denied")
}
s.keys[i].IsActive = false
return nil
}
}
return fmt.Errorf("key not found")
}
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
// Uses constant-time comparison to prevent timing side-channels.
// Pass an empty KeyType to accept any type.
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
hashBytes, _ := hex.DecodeString(hash)
now := time.Now()
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
k := &s.keys[i]
if !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
stored, _ := hex.DecodeString(k.KeyHash)
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
continue
}
k.LastUsedAt = &now
result := *k
return &result, nil
}
return nil, fmt.Errorf("invalid or expired key")
}

View File

@@ -0,0 +1,256 @@
package security
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
type DatabaseKeyStoreOptions struct {
// Cache is an optional cache instance. If nil, uses the default cache.
Cache *cache.Cache
// CacheTTL is the duration to cache ValidateKey results.
// Default: 2 minutes.
CacheTTL time.Duration
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
SQLNames *KeyStoreSQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
// All DB operations go through configurable procedure names; the raw key is
// never passed to the database.
//
// See keystore_schema.sql for the required table and procedure definitions.
//
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
}
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
o := DatabaseKeyStoreOptions{}
if len(opts) > 0 {
o = opts[0]
}
if o.CacheTTL == 0 {
o.CacheTTL = 2 * time.Minute
}
c := o.Cache
if c == nil {
c = cache.GetDefaultCache()
}
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
return &DatabaseKeyStore{
db: db,
dbFactory: o.DBFactory,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
}
}
func (ks *DatabaseKeyStore) getDB() *sql.DB {
ks.dbMu.RLock()
defer ks.dbMu.RUnlock()
return ks.db
}
func (ks *DatabaseKeyStore) reconnectDB() error {
if ks.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := ks.dbFactory()
if err != nil {
return err
}
ks.dbMu.Lock()
ks.db = newDB
ks.dbMu.Unlock()
return nil
}
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
// and returns the raw key once.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
type createRequest struct {
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"`
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
reqJSON, err := json.Marshal(createRequest{
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
})
if err != nil {
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("create key procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
}
var key UserKey
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse created key: %w", err)
}
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
var success bool
var errorMsg sql.NullString
var keysJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
}
var keys []UserKey
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
return nil, fmt.Errorf("failed to parse user keys: %w", err)
}
}
if keys == nil {
keys = []UserKey{}
}
return keys, nil
}
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
// The delete procedure returns the key_hash so no separate lookup is needed.
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
var success bool
var errorMsg sql.NullString
var keyHash sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
return fmt.Errorf("delete key procedure failed: %w", err)
}
if !success {
return errors.New(nullStringOr(errorMsg, "delete key failed"))
}
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
}
return nil
}
// ValidateKey hashes the raw key and calls the validate procedure.
// Results are cached for CacheTTL to reduce DB load on hot paths.
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
cacheKey := keystoreCacheKey(hash)
if ks.cache != nil {
var cached UserKey
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
if cached.IsActive {
return &cached, nil
}
return nil, errors.New("invalid or expired key")
}
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
}
if err := runQuery(); err != nil {
if isDBClosed(err) {
if reconnErr := ks.reconnectDB(); reconnErr == nil {
err = runQuery()
}
if err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
} else {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
}
var key UserKey
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse validated key: %w", err)
}
if ks.cache != nil {
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
}
return &key, nil
}
func keystoreCacheKey(hash string) string {
return "keystore:validate:" + hash
}
// nullStringOr returns s.String if valid, otherwise the fallback.
func nullStringOr(s sql.NullString, fallback string) string {
if s.Valid && s.String != "" {
return s.String
}
return fallback
}

View File

@@ -0,0 +1,187 @@
-- Keystore schema for per-user auth keys
-- Apply alongside database_schema.sql (requires the users table)
CREATE TABLE IF NOT EXISTS user_keys (
id BIGSERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_type VARCHAR(50) NOT NULL,
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
name VARCHAR(255) NOT NULL DEFAULT '',
scopes TEXT, -- JSON array, e.g. '["read","write"]'
meta JSONB,
expires_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used_at TIMESTAMP,
is_active BOOLEAN DEFAULT true
);
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
-- resolvespec_keystore_get_user_keys
-- Returns all active, non-expired keys for a user.
-- Pass empty p_key_type to return all key types.
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
p_user_id INTEGER,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_keys JSONB;
BEGIN
SELECT COALESCE(
jsonb_agg(
jsonb_build_object(
'id', k.id,
'user_id', k.user_id,
'key_type', k.key_type,
'name', k.name,
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
'meta', COALESCE(k.meta, '{}'::jsonb),
'expires_at', k.expires_at,
'created_at', k.created_at,
'last_used_at', k.last_used_at,
'is_active', k.is_active
)
),
'[]'::jsonb
)
INTO v_keys
FROM user_keys k
WHERE k.user_id = p_user_id
AND k.is_active = true
AND (k.expires_at IS NULL OR k.expires_at > NOW())
AND (p_key_type = '' OR k.key_type = p_key_type);
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_create_key
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
-- Returns the created key record (without key_hash).
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
p_request JSONB
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_id BIGINT;
v_created_at TIMESTAMP;
v_key JSONB;
BEGIN
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
VALUES (
(p_request->>'user_id')::INTEGER,
p_request->>'key_type',
p_request->>'key_hash',
COALESCE(p_request->>'name', ''),
p_request->>'scopes',
p_request->'meta',
CASE WHEN p_request->>'expires_at' IS NOT NULL
THEN (p_request->>'expires_at')::TIMESTAMP
ELSE NULL
END
)
RETURNING id, created_at INTO v_id, v_created_at;
v_key := jsonb_build_object(
'id', v_id,
'user_id', (p_request->>'user_id')::INTEGER,
'key_type', p_request->>'key_type',
'name', COALESCE(p_request->>'name', ''),
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
THEN (p_request->>'scopes')::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
'expires_at', p_request->>'expires_at',
'created_at', v_created_at,
'is_active', true
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_delete_key
-- Soft-deletes a key (is_active = false) after verifying ownership.
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
p_user_id INTEGER,
p_key_id BIGINT
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
LANGUAGE plpgsql AS $$
DECLARE
v_hash TEXT;
BEGIN
UPDATE user_keys
SET is_active = false
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
RETURNING key_hash INTO v_hash;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
RETURN;
END IF;
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
END;
$$;
-- resolvespec_keystore_validate_key
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
-- updates last_used_at, and returns the key record.
-- p_key_type can be empty to accept any key type.
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
p_key_hash TEXT,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_key_rec user_keys%ROWTYPE;
v_key JSONB;
BEGIN
SELECT * INTO v_key_rec
FROM user_keys
WHERE key_hash = p_key_hash
AND is_active = true
AND (expires_at IS NULL OR expires_at > NOW())
AND (p_key_type = '' OR key_type = p_key_type);
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
RETURN;
END IF;
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
v_key := jsonb_build_object(
'id', v_key_rec.id,
'user_id', v_key_rec.user_id,
'key_type', v_key_rec.key_type,
'name', v_key_rec.name,
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
THEN v_key_rec.scopes::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
'expires_at', v_key_rec.expires_at,
'created_at', v_key_rec.created_at,
'last_used_at', NOW(),
'is_active', v_key_rec.is_active
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;

View File

@@ -0,0 +1,61 @@
package security
import "fmt"
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
type KeyStoreSQLNames struct {
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
CreateKey string // default: "resolvespec_keystore_create_key"
DeleteKey string // default: "resolvespec_keystore_delete_key"
ValidateKey string // default: "resolvespec_keystore_validate_key"
}
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
return &KeyStoreSQLNames{
GetUserKeys: "resolvespec_keystore_get_user_keys",
CreateKey: "resolvespec_keystore_create_key",
DeleteKey: "resolvespec_keystore_delete_key",
ValidateKey: "resolvespec_keystore_validate_key",
}
}
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.GetUserKeys != "" {
merged.GetUserKeys = override.GetUserKeys
}
if override.CreateKey != "" {
merged.CreateKey = override.CreateKey
}
if override.DeleteKey != "" {
merged.DeleteKey = override.DeleteKey
}
if override.ValidateKey != "" {
merged.ValidateKey = override.ValidateKey
}
return &merged
}
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
fields := map[string]string{
"GetUserKeys": names.GetUserKeys,
"CreateKey": names.CreateKey,
"DeleteKey": names.DeleteKey,
"ValidateKey": names.ValidateKey,
}
for field, val := range fields {
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
}
}
return nil
}

View File

@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string
var userID *int
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error, p_user_id
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
`, userJSON).Scan(&success, &errMsg, &userID)
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id
FROM %s($1::jsonb)
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
if err != nil {
return 0, fmt.Errorf("failed to get or create user: %w", err)
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error
FROM resolvespec_oauth_createsession($1::jsonb)
`, sessionJSON).Scan(&success, &errMsg)
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string
var sessionData []byte
err = a.db.QueryRowContext(ctx, `
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getrefreshtoken($1)
`, refreshToken).Scan(&success, &errMsg, &sessionData)
FROM %s($1)
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
if err != nil {
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool
var updateErrMsg *string
err = a.db.QueryRowContext(ctx, `
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
FROM %s($1::jsonb)
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
if err != nil {
return nil, fmt.Errorf("failed to update session: %w", err)
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string
var userData []byte
err = a.db.QueryRowContext(ctx, `
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getuser($1)
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
FROM %s($1)
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)

View File

@@ -0,0 +1,917 @@
package security
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
type OAuthServerConfig struct {
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
Issuer string
// ProviderCallbackPath is the path on this server that external OAuth2 providers
// redirect back to. Defaults to "/oauth/provider/callback".
ProviderCallbackPath string
// LoginTitle is shown on the built-in login form when the server acts as its own
// identity provider. Defaults to "Sign in".
LoginTitle string
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
// Clients registered during a session survive server restarts.
PersistClients bool
// PersistCodes stores authorization codes in the database.
// Useful for multi-instance deployments. Defaults to in-memory.
PersistCodes bool
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
DefaultScopes []string
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
AccessTokenTTL time.Duration
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
AuthCodeTTL time.Duration
}
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
type oauthClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
// pendingAuth tracks an in-progress authorization code exchange.
type pendingAuth struct {
ClientID string
RedirectURI string
ClientState string
CodeChallenge string
CodeChallengeMethod string
ProviderName string // empty = password login
ExpiresAt time.Time
SessionToken string // set after authentication completes
RefreshToken string // set after authentication completes when refresh tokens are issued
Scopes []string // requested scopes
}
// externalProvider pairs a DatabaseAuthenticator with its provider name.
type externalProvider struct {
auth *DatabaseAuthenticator
providerName string
}
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
//
// It can act as both:
// - A direct identity provider using DatabaseAuthenticator username/password login
// - A federation layer that delegates authentication to external OAuth2 providers
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
//
// The server exposes these RFC-compliant endpoints:
//
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
// POST /oauth/register RFC 7591 — dynamic client registration
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
// POST /oauth/authorize Direct login form submission
// POST /oauth/token Token exchange and refresh
// POST /oauth/revoke RFC 7009 — token revocation
// POST /oauth/introspect RFC 7662 — token introspection
// GET {ProviderCallbackPath} Internal — external provider callback
type OAuthServer struct {
cfg OAuthServerConfig
auth *DatabaseAuthenticator // nil = only external providers
providers []externalProvider
mu sync.RWMutex
clients map[string]*oauthClient
pending map[string]*pendingAuth // provider_state → pending (external flow)
codes map[string]*pendingAuth // auth_code → pending (post-auth)
done chan struct{} // closed by Close() to stop background goroutines
}
// NewOAuthServer creates a new MCP OAuth2 authorization server.
//
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
// acts as its own identity provider). Pass nil to use only external providers.
// External providers are added separately via RegisterExternalProvider.
//
// Call Close() to stop background goroutines when the server is no longer needed.
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
if cfg.ProviderCallbackPath == "" {
cfg.ProviderCallbackPath = "/oauth/provider/callback"
}
if cfg.LoginTitle == "" {
cfg.LoginTitle = "Sign in"
}
if len(cfg.DefaultScopes) == 0 {
cfg.DefaultScopes = []string{"openid", "profile", "email"}
}
if cfg.AccessTokenTTL == 0 {
cfg.AccessTokenTTL = 24 * time.Hour
}
if cfg.AuthCodeTTL == 0 {
cfg.AuthCodeTTL = 2 * time.Minute
}
// Normalize issuer: remove trailing slash to ensure consistent endpoint URL construction.
cfg.Issuer = strings.TrimSuffix(cfg.Issuer, "/")
s := &OAuthServer{
cfg: cfg,
auth: auth,
clients: make(map[string]*oauthClient),
pending: make(map[string]*pendingAuth),
codes: make(map[string]*pendingAuth),
done: make(chan struct{}),
}
go s.cleanupExpired()
return s
}
// Close stops the background goroutines started by NewOAuthServer.
// It is safe to call Close multiple times.
func (s *OAuthServer) Close() {
select {
case <-s.done:
// already closed
default:
close(s.done)
}
}
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
// configured with WithOAuth2(providerName, ...) before calling this.
// Multiple providers can be registered; the first is used as the default.
// All providers must be registered before the server starts serving requests.
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
s.mu.Lock()
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
s.mu.Unlock()
}
// ProviderCallbackPath returns the configured path for external provider callbacks.
func (s *OAuthServer) ProviderCallbackPath() string {
return s.cfg.ProviderCallbackPath
}
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
// Mount it at the root of your HTTP server alongside the MCP transport.
//
// mux := http.NewServeMux()
// mux.Handle("/", oauthServer.HTTPHandler())
// mux.Handle("/mcp/", mcpTransport)
func (s *OAuthServer) HTTPHandler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
mux.HandleFunc("/oauth/register", s.registerHandler)
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
mux.HandleFunc("/oauth/token", s.tokenHandler)
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
return mux
}
// cleanupExpired removes stale pending auths and codes every 5 minutes.
func (s *OAuthServer) cleanupExpired() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.done:
return
case <-ticker.C:
now := time.Now()
s.mu.Lock()
for k, p := range s.pending {
if now.After(p.ExpiresAt) {
delete(s.pending, k)
}
}
for k, p := range s.codes {
if now.After(p.ExpiresAt) {
delete(s.codes, k)
}
}
s.mu.Unlock()
}
}
}
// --------------------------------------------------------------------------
// RFC 8414 — Server metadata
// --------------------------------------------------------------------------
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
issuer := s.cfg.Issuer
meta := map[string]interface{}{
"issuer": issuer,
"authorization_endpoint": issuer + "/oauth/authorize",
"token_endpoint": issuer + "/oauth/token",
"registration_endpoint": issuer + "/oauth/register",
"revocation_endpoint": issuer + "/oauth/revoke",
"introspection_endpoint": issuer + "/oauth/introspect",
"scopes_supported": s.cfg.DefaultScopes,
"response_types_supported": []string{"code"},
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"code_challenge_methods_supported": []string{"S256"},
"token_endpoint_auth_methods_supported": []string{"none"},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(meta) //nolint:errcheck
}
// --------------------------------------------------------------------------
// RFC 7591 — Dynamic client registration
// --------------------------------------------------------------------------
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
return
}
if len(req.RedirectURIs) == 0 {
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
return
}
grantTypes := req.GrantTypes
if len(grantTypes) == 0 {
grantTypes = []string{"authorization_code"}
}
allowedScopes := req.AllowedScopes
if len(allowedScopes) == 0 {
allowedScopes = s.cfg.DefaultScopes
}
clientID, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
client := &oauthClient{
ClientID: clientID,
RedirectURIs: req.RedirectURIs,
ClientName: req.ClientName,
GrantTypes: grantTypes,
AllowedScopes: allowedScopes,
}
if s.cfg.PersistClients && s.auth != nil {
dbClient := &OAuthServerClient{
ClientID: client.ClientID,
RedirectURIs: client.RedirectURIs,
ClientName: client.ClientName,
GrantTypes: client.GrantTypes,
AllowedScopes: client.AllowedScopes,
}
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
}
s.mu.Lock()
s.clients[clientID] = client
s.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(client) //nolint:errcheck
}
// --------------------------------------------------------------------------
// Authorization endpoint — GET + POST /oauth/authorize
// --------------------------------------------------------------------------
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
s.authorizeGet(w, r)
case http.MethodPost:
s.authorizePost(w, r)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// authorizeGet validates the request and either:
// - Redirects to an external provider (if providers are registered)
// - Renders a login form (if the server is its own identity provider)
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
clientID := q.Get("client_id")
redirectURI := q.Get("redirect_uri")
clientState := q.Get("state")
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
providerName := q.Get("provider")
scopeStr := q.Get("scope")
var scopes []string
if scopeStr != "" {
scopes = strings.Fields(scopeStr)
}
if q.Get("response_type") != "code" {
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
return
}
if codeChallenge == "" {
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
return
}
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
return
}
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
if !ok {
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
return
}
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
return
}
// External provider path
if len(s.providers) > 0 {
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
return
}
// Direct login form path (server is its own identity provider)
if s.auth == nil {
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
return
}
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
}
// authorizePost handles login form submission for the direct login flow.
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
clientID := r.FormValue("client_id")
redirectURI := r.FormValue("redirect_uri")
clientState := r.FormValue("client_state")
codeChallenge := r.FormValue("code_challenge")
codeChallengeMethod := r.FormValue("code_challenge_method")
username := r.FormValue("username")
password := r.FormValue("password")
scopeStr := r.FormValue("scope")
var scopes []string
if scopeStr != "" {
scopes = strings.Fields(scopeStr)
}
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
return
}
if s.auth == nil {
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
return
}
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
Username: username,
Password: password,
})
if err != nil {
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
return
}
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
}
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
var provider *externalProvider
if providerName != "" {
for i := range s.providers {
if s.providers[i].providerName == providerName {
provider = &s.providers[i]
break
}
}
if provider == nil {
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
return
}
} else {
provider = &s.providers[0]
}
providerState, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
pending := &pendingAuth{
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ProviderName: provider.providerName,
ExpiresAt: time.Now().Add(10 * time.Minute),
Scopes: scopes,
}
s.mu.Lock()
s.pending[providerState] = pending
s.mu.Unlock()
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusFound)
}
// --------------------------------------------------------------------------
// External provider callback — GET {ProviderCallbackPath}
// --------------------------------------------------------------------------
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
providerState := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code", http.StatusBadRequest)
return
}
s.mu.Lock()
pending, ok := s.pending[providerState]
if ok {
delete(s.pending, providerState)
}
s.mu.Unlock()
if !ok || time.Now().After(pending.ExpiresAt) {
http.Error(w, "invalid or expired state", http.StatusBadRequest)
return
}
provider := s.providerByName(pending.ProviderName)
if provider == nil {
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
return
}
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken,
pending.ClientID, pending.RedirectURI, pending.ClientState,
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
}
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, refreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
authCode, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
pending := &pendingAuth{
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ProviderName: providerName,
SessionToken: sessionToken,
RefreshToken: refreshToken,
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
Scopes: scopes,
}
if s.cfg.PersistCodes && s.auth != nil {
oauthCode := &OAuthCode{
Code: authCode,
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
SessionToken: sessionToken,
RefreshToken: refreshToken,
Scopes: scopes,
ExpiresAt: pending.ExpiresAt,
}
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
} else {
s.mu.Lock()
s.codes[authCode] = pending
s.mu.Unlock()
}
redirectURL, err := url.Parse(redirectURI)
if err != nil {
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
return
}
qp := redirectURL.Query()
qp.Set("code", authCode)
if clientState != "" {
qp.Set("state", clientState)
}
redirectURL.RawQuery = qp.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// --------------------------------------------------------------------------
// Token endpoint — POST /oauth/token
// --------------------------------------------------------------------------
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
return
}
switch r.FormValue("grant_type") {
case "authorization_code":
s.handleAuthCodeGrant(w, r)
case "refresh_token":
s.handleRefreshGrant(w, r)
default:
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
}
}
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
code := r.FormValue("code")
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
codeVerifier := r.FormValue("code_verifier")
if code == "" || codeVerifier == "" {
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
return
}
var sessionToken string
var refreshToken string
var scopes []string
if s.cfg.PersistCodes && s.auth != nil {
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
if err != nil {
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
return
}
if oauthCode.ClientID != clientID {
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
return
}
if oauthCode.RedirectURI != redirectURI {
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
return
}
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
return
}
sessionToken = oauthCode.SessionToken
refreshToken = oauthCode.RefreshToken
scopes = oauthCode.Scopes
} else {
s.mu.Lock()
pending, ok := s.codes[code]
if ok {
delete(s.codes, code)
}
s.mu.Unlock()
if !ok || time.Now().After(pending.ExpiresAt) {
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
return
}
if pending.ClientID != clientID {
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
return
}
if pending.RedirectURI != redirectURI {
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
return
}
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
return
}
sessionToken = pending.SessionToken
refreshToken = pending.RefreshToken
scopes = pending.Scopes
}
s.writeOAuthToken(w, sessionToken, refreshToken, scopes)
}
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
refreshToken := r.FormValue("refresh_token")
providerName := r.FormValue("provider")
if refreshToken == "" {
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
return
}
// Try external providers first, then fall back to DatabaseAuthenticator
provider := s.providerByName(providerName)
if provider != nil {
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
if err != nil {
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
return
}
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
return
}
if s.auth != nil {
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
if err != nil {
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
return
}
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
return
}
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
}
// --------------------------------------------------------------------------
// RFC 7009 — Token revocation
// --------------------------------------------------------------------------
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusOK)
return
}
token := r.FormValue("token")
if token == "" {
w.WriteHeader(http.StatusOK)
return
}
if s.auth != nil {
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
} else {
// In external-provider-only mode, attempt revocation via the first provider's auth.
s.mu.RLock()
var providerAuth *DatabaseAuthenticator
if len(s.providers) > 0 {
providerAuth = s.providers[0].auth
}
s.mu.RUnlock()
if providerAuth != nil {
providerAuth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
}
}
w.WriteHeader(http.StatusOK)
}
// --------------------------------------------------------------------------
// RFC 7662 — Token introspection
// --------------------------------------------------------------------------
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
token := r.FormValue("token")
w.Header().Set("Content-Type", "application/json")
if token == "" {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
// Resolve the authenticator to use: prefer the primary auth, then the first provider's auth.
authToUse := s.auth
if authToUse == nil {
s.mu.RLock()
if len(s.providers) > 0 {
authToUse = s.providers[0].auth
}
s.mu.RUnlock()
}
if authToUse == nil {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
info, err := authToUse.OAuthIntrospectToken(r.Context(), token)
if err != nil {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
json.NewEncoder(w).Encode(info) //nolint:errcheck
}
// --------------------------------------------------------------------------
// Login form (direct identity provider mode)
// --------------------------------------------------------------------------
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
errHTML := ""
if errMsg != "" {
errHTML = `<p style="color:red">` + errMsg + `</p>`
}
fmt.Fprintf(w, loginFormHTML,
s.cfg.LoginTitle,
s.cfg.LoginTitle,
errHTML,
clientID,
htmlEscape(redirectURI),
htmlEscape(clientState),
htmlEscape(codeChallenge),
htmlEscape(codeChallengeMethod),
htmlEscape(scope),
)
}
const loginFormHTML = `<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>%s</title>
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
h2{margin:0 0 1.5rem;font-size:1.25rem}
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
</head><body><div class="card">
<h2>%s</h2>%s
<form method="POST" action="authorize">
<input type="hidden" name="client_id" value="%s">
<input type="hidden" name="redirect_uri" value="%s">
<input type="hidden" name="client_state" value="%s">
<input type="hidden" name="code_challenge" value="%s">
<input type="hidden" name="code_challenge_method" value="%s">
<input type="hidden" name="scope" value="%s">
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
<label>Password</label><input type="password" name="password" autocomplete="current-password">
<button type="submit">Sign in</button>
</form></div></body></html>`
// --------------------------------------------------------------------------
// Helpers
// --------------------------------------------------------------------------
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
s.mu.RLock()
c, ok := s.clients[clientID]
s.mu.RUnlock()
if ok {
return c, true
}
if !s.cfg.PersistClients || s.auth == nil {
return nil, false
}
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
if err != nil {
return nil, false
}
c = &oauthClient{
ClientID: dbClient.ClientID,
RedirectURIs: dbClient.RedirectURIs,
ClientName: dbClient.ClientName,
GrantTypes: dbClient.GrantTypes,
AllowedScopes: dbClient.AllowedScopes,
}
s.mu.Lock()
s.clients[clientID] = c
s.mu.Unlock()
return c, true
}
func (s *OAuthServer) providerByName(name string) *externalProvider {
for i := range s.providers {
if s.providers[i].providerName == name {
return &s.providers[i]
}
}
// If name is empty and only one provider exists, return it
if name == "" && len(s.providers) == 1 {
return &s.providers[0]
}
return nil
}
func validatePKCESHA256(challenge, verifier string) bool {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
}
func randomOAuthToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func oauthSliceContains(slice []string, s string) bool {
for _, v := range slice {
if v == s {
return true
}
}
return false
}
func (s *OAuthServer) writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
expiresIn := int64(s.cfg.AccessTokenTTL.Seconds())
resp := map[string]interface{}{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": expiresIn,
}
if refreshToken != "" {
resp["refresh_token"] = refreshToken
}
if len(scopes) > 0 {
resp["scope"] = strings.Join(scopes, " ")
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
json.NewEncoder(w).Encode(resp) //nolint:errcheck
}
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
resp := map[string]string{"error": errCode}
if description != "" {
resp["error_description"] = description
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(resp) //nolint:errcheck
}
func htmlEscape(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, `"`, "&#34;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
return s
}

View File

@@ -0,0 +1,204 @@
package security
import (
"context"
"encoding/json"
"fmt"
"time"
)
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
type OAuthServerClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
// OAuthCode is a short-lived authorization code.
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
ClientState string `json:"client_state,omitempty"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
SessionToken string `json:"session_token"`
RefreshToken string `json:"refresh_token,omitempty"`
Scopes []string `json:"scopes,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
}
// OAuthTokenInfo is the RFC 7662 token introspection response.
type OAuthTokenInfo struct {
Active bool `json:"active"`
Sub string `json:"sub,omitempty"`
Username string `json:"username,omitempty"`
Email string `json:"email,omitempty"`
UserLevel int `json:"user_level,omitempty"`
Roles []string `json:"roles,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
}
// OAuthRegisterClient persists an OAuth2 client registration.
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
input, err := json.Marshal(client)
if err != nil {
return nil, fmt.Errorf("failed to marshal client: %w", err)
}
var success bool
var errMsg *string
var data []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1::jsonb)
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("failed to register client")
}
var result OAuthServerClient
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse registered client: %w", err)
}
return &result, nil
}
// OAuthGetClient retrieves a registered client by ID.
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("client not found")
}
var result OAuthServerClient
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse client: %w", err)
}
return &result, nil
}
// OAuthSaveCode persists an authorization code.
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
input, err := json.Marshal(code)
if err != nil {
return fmt.Errorf("failed to marshal code: %w", err)
}
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to save code: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to save code")
}
return nil
}
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("invalid or expired code")
}
var result OAuthCode
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse code data: %w", err)
}
result.Code = code
return &result, nil
}
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to introspect token: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("introspection failed")
}
var result OAuthTokenInfo
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse token info: %w", err)
}
return &result, nil
}
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
var success bool
var errMsg *string
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1)
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to revoke token")
}
return nil
}

View File

@@ -7,16 +7,21 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"sync"
"time"
)
// DatabasePasskeyProvider implements PasskeyProvider using database storage
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct {
db *sql.DB
rpID string // Relying Party ID (domain)
rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
rpID string // Relying Party ID (domain)
rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
}
// DatabasePasskeyProviderOptions configures the passkey provider
@@ -29,6 +34,11 @@ type DatabasePasskeyProviderOptions struct {
RPOrigin string
// Timeout is the timeout for operations in milliseconds (default: 60000)
Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -37,15 +47,39 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
opts.Timeout = 60000 // 60 seconds default
}
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabasePasskeyProvider{
db: db,
rpID: opts.RPID,
rpName: opts.RPName,
rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
db: db,
dbFactory: opts.DBFactory,
rpID: opts.RPID,
rpName: opts.RPName,
rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
sqlNames: sqlNames,
}
}
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabasePasskeyProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// BeginRegistration creates registration options for a new passkey
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
// Generate challenge
@@ -132,8 +166,8 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var errorMsg sql.NullString
var credentialID sql.NullInt64
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err)
}
@@ -173,8 +207,8 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var userID sql.NullInt64
var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
@@ -233,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString
var credentialJSON sql.NullString
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err)
}
@@ -264,8 +306,8 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var updateError sql.NullString
var cloneWarning sql.NullBool
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err)
}
@@ -283,8 +325,8 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var errorMsg sql.NullString
var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
@@ -362,8 +404,8 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to delete credential: %w", err)
}
@@ -388,8 +430,8 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to update credential name: %w", err)
}

View File

@@ -58,15 +58,17 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// DatabaseAuthenticator provides session-based authentication with database storage
// All database operations go through stored procedures for security and consistency
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
// resolvespec_session_update, resolvespec_refresh_token
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// See database_schema.sql for procedure definitions
// Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey()
type DatabaseAuthenticator struct {
db *sql.DB
cache *cache.Cache
cacheTTL time.Duration
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
cache *cache.Cache
cacheTTL time.Duration
sqlNames *SQLNames
// OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider
@@ -85,6 +87,12 @@ type DatabaseAuthenticatorOptions struct {
Cache *cache.Cache
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
PasskeyProvider PasskeyProvider
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -103,14 +111,54 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
cacheInstance = cache.GetDefaultCache()
}
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{
db: db,
dbFactory: opts.DBFactory,
cache: cacheInstance,
cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider,
}
}
func (a *DatabaseAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *DatabaseAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
db := a.getDB()
if db == nil {
return fmt.Errorf("database connection is nil")
}
err := run(db)
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = run(a.getDB())
}
}
return err
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req)
@@ -118,13 +166,14 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
return nil, fmt.Errorf("failed to marshal login request: %w", err)
}
// Call resolvespec_login stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -153,13 +202,14 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
return nil, fmt.Errorf("failed to marshal register request: %w", err)
}
// Call resolvespec_register stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("register query failed: %w", err)
}
@@ -187,13 +237,14 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
return fmt.Errorf("failed to marshal logout request: %w", err)
}
// Call resolvespec_logout stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -266,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
})
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
@@ -338,25 +391,27 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
return
}
// Call resolvespec_session_update stored procedure
var success bool
var errorMsg sql.NullString
var updatedUserJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
})
}
// RefreshToken implements Refreshable interface
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Call api_refresh_token stored procedure
// First, we need to get the current user context for the refresh token
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
// Get current session to pass to refresh
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
})
if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err)
}
@@ -368,13 +423,14 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
return nil, fmt.Errorf("invalid refresh token")
}
// Call resolvespec_refresh_token to generate new token
var newSuccess bool
var newErrorMsg sql.NullString
var newUserJSON sql.NullString
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
})
if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err)
}
@@ -401,28 +457,65 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
// JWTAuthenticator provides JWT token-based authentication
// All database operations go through stored procedures
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
type JWTAuthenticator struct {
secretKey []byte
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
}
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
return &JWTAuthenticator{
secretKey: []byte(secretKey),
db: db,
sqlNames: resolveSQLNames(names...),
}
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
a.dbFactory = factory
return a
}
func (a *JWTAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *JWTAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Call resolvespec_jwt_login stored procedure
var success bool
var errorMsg sql.NullString
var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
runLoginQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
}
err := runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -471,12 +564,11 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Call resolvespec_jwt_logout stored procedure
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -511,25 +603,60 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// DatabaseColumnSecurityProvider loads column security from database
// All database operations go through stored procedures
// Requires stored procedure: resolvespec_column_security
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct {
db *sql.DB
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
}
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db}
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity
// Call resolvespec_column_security stored procedure
var success bool
var errorMsg sql.NullString
var rulesJSON []byte
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err)
}
@@ -576,23 +703,57 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// DatabaseRowSecurityProvider loads row security from database
// All database operations go through stored procedures
// Requires stored procedure: resolvespec_row_security
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct {
db *sql.DB
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
}
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db}
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string
var hasBlock bool
// Call resolvespec_row_security stored procedure
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil {
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
}
@@ -662,6 +823,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
// Helper functions
// ================
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
func parseRoles(rolesStr string) []string {
if rolesStr == "" {
return []string{}
@@ -758,56 +924,55 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
return nil, fmt.Errorf("passkey authentication failed: %w", err)
}
// Get user data from database
var username, email, roles string
var userLevel int
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
// Build request JSON for passkey login stored procedure
reqData := map[string]any{
"user_id": userID,
}
// Generate session token
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Extract IP and user agent from claims
ipAddress := ""
userAgent := ""
if req.Claims != nil {
if ip, ok := req.Claims["ip_address"].(string); ok {
ipAddress = ip
reqData["ip_address"] = ip
}
if ua, ok := req.Claims["user_agent"].(string); ok {
userAgent = ua
reqData["user_agent"] = ua
}
}
// Create session
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES ($1, $2, $3, $4, $5, now())`
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
reqJSON, err := json.Marshal(reqData)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
}
// Update last login
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
// Return login response
return &LoginResponse{
Token: sessionToken,
User: &UserContext{
UserID: userID,
UserName: username,
Email: email,
UserLevel: userLevel,
SessionID: sessionToken,
Roles: parseRoles(roles),
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
runPasskeyQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runPasskeyQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runPasskeyQuery()
}
}
if err != nil {
return nil, fmt.Errorf("passkey login query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("passkey login failed")
}
var response LoginResponse
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
}
return &response, nil
}
// GetPasskeyCredentials returns all passkey credentials for a user

View File

@@ -3,6 +3,7 @@ package security
import (
"context"
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
})
}
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
t.Helper()
primaryDB, primaryMock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create primary mock db: %v", err)
}
reconnectDB, reconnectMock, err := sqlmock.New()
if err != nil {
primaryDB.Close()
t.Fatalf("failed to create reconnect mock db: %v", err)
}
cacheProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 1000,
})
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
Cache: cache.NewCache(cacheProvider),
DBFactory: func() (*sql.DB, error) {
return reconnectDB, nil
},
})
cleanup := func() {
_ = primaryDB.Close()
_ = reconnectDB.Close()
}
return auth, primaryMock, reconnectMock, cleanup
}
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("reconnect-auth-token", "authenticate").
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("reconnect-auth-token", "authenticate").
WillReturnRows(reconnectRows)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected authenticate to reconnect, got %v", err)
}
if userCtx.UserID != 7 {
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("Register reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
req := RegisterRequest{
Username: "reconnect-register",
Password: "password123",
Email: "reconnect@example.com",
UserLevel: 1,
Roles: []string{"user"},
}
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(reconnectRows)
resp, err := auth.Register(context.Background(), req)
if err != nil {
t.Fatalf("expected register to reconnect, got %v", err)
}
if resp.Token != "reconnected-register-token" {
t.Fatalf("expected refreshed token, got %s", resp.Token)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("Logout reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
WithArgs(sqlmock.AnyArg()).
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(true, nil, nil)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(reconnectRows)
if err := auth.Logout(context.Background(), req); err != nil {
t.Fatalf("expected logout to reconnect, got %v", err)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
refreshToken := "refresh-reconnect-token"
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs(refreshToken, "refresh").
WillReturnError(fmt.Errorf("sql: database is closed"))
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs(refreshToken, "refresh").
WillReturnRows(sessionRows)
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
WithArgs(refreshToken, sqlmock.AnyArg()).
WillReturnRows(refreshRows)
resp, err := auth.RefreshToken(context.Background(), refreshToken)
if err != nil {
t.Fatalf("expected refresh token to reconnect, got %v", err)
}
if resp.Token != "refreshed-token" {
t.Fatalf("expected refreshed-token, got %s", resp.Token)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
WithArgs("activity-token", sqlmock.AnyArg()).
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
WithArgs("activity-token", sqlmock.AnyArg()).
WillReturnRows(reconnectRows)
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
}
// Test JWTAuthenticator
func TestJWTAuthenticator(t *testing.T) {
db, mock, err := sqlmock.New()

254
pkg/security/sql_names.go Normal file
View File

@@ -0,0 +1,254 @@
package security
import (
"fmt"
"reflect"
"regexp"
)
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// SQLNames defines all configurable SQL stored procedure and table names
// used by the security package. Override individual fields to remap
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
// and MergeSQLNames() to apply partial overrides.
type SQLNames struct {
// Auth procedures (DatabaseAuthenticator)
Login string // default: "resolvespec_login"
Register string // default: "resolvespec_register"
Logout string // default: "resolvespec_logout"
Session string // default: "resolvespec_session"
SessionUpdate string // default: "resolvespec_session_update"
RefreshToken string // default: "resolvespec_refresh_token"
// JWT procedures (JWTAuthenticator)
JWTLogin string // default: "resolvespec_jwt_login"
JWTLogout string // default: "resolvespec_jwt_logout"
// Security policy procedures
ColumnSecurity string // default: "resolvespec_column_security"
RowSecurity string // default: "resolvespec_row_security"
// TOTP procedures (DatabaseTwoFactorProvider)
TOTPEnable string // default: "resolvespec_totp_enable"
TOTPDisable string // default: "resolvespec_totp_disable"
TOTPGetStatus string // default: "resolvespec_totp_get_status"
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
// Passkey procedures (DatabasePasskeyProvider)
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
PasskeyLogin string // default: "resolvespec_passkey_login"
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser"
// OAuth2 server procedures (OAuthServer persistence)
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
OAuthGetClient string // default: "resolvespec_oauth_get_client"
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
OAuthRevoke string // default: "resolvespec_oauth_revoke"
}
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
func DefaultSQLNames() *SQLNames {
return &SQLNames{
Login: "resolvespec_login",
Register: "resolvespec_register",
Logout: "resolvespec_logout",
Session: "resolvespec_session",
SessionUpdate: "resolvespec_session_update",
RefreshToken: "resolvespec_refresh_token",
JWTLogin: "resolvespec_jwt_login",
JWTLogout: "resolvespec_jwt_logout",
ColumnSecurity: "resolvespec_column_security",
RowSecurity: "resolvespec_row_security",
TOTPEnable: "resolvespec_totp_enable",
TOTPDisable: "resolvespec_totp_disable",
TOTPGetStatus: "resolvespec_totp_get_status",
TOTPGetSecret: "resolvespec_totp_get_secret",
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
PasskeyGetCredential: "resolvespec_passkey_get_credential",
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
PasskeyUpdateName: "resolvespec_passkey_update_name",
PasskeyLogin: "resolvespec_passkey_login",
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
OAuthCreateSession: "resolvespec_oauth_createsession",
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser",
OAuthRegisterClient: "resolvespec_oauth_register_client",
OAuthGetClient: "resolvespec_oauth_get_client",
OAuthSaveCode: "resolvespec_oauth_save_code",
OAuthExchangeCode: "resolvespec_oauth_exchange_code",
OAuthIntrospect: "resolvespec_oauth_introspect",
OAuthRevoke: "resolvespec_oauth_revoke",
}
}
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.Login != "" {
merged.Login = override.Login
}
if override.Register != "" {
merged.Register = override.Register
}
if override.Logout != "" {
merged.Logout = override.Logout
}
if override.Session != "" {
merged.Session = override.Session
}
if override.SessionUpdate != "" {
merged.SessionUpdate = override.SessionUpdate
}
if override.RefreshToken != "" {
merged.RefreshToken = override.RefreshToken
}
if override.JWTLogin != "" {
merged.JWTLogin = override.JWTLogin
}
if override.JWTLogout != "" {
merged.JWTLogout = override.JWTLogout
}
if override.ColumnSecurity != "" {
merged.ColumnSecurity = override.ColumnSecurity
}
if override.RowSecurity != "" {
merged.RowSecurity = override.RowSecurity
}
if override.TOTPEnable != "" {
merged.TOTPEnable = override.TOTPEnable
}
if override.TOTPDisable != "" {
merged.TOTPDisable = override.TOTPDisable
}
if override.TOTPGetStatus != "" {
merged.TOTPGetStatus = override.TOTPGetStatus
}
if override.TOTPGetSecret != "" {
merged.TOTPGetSecret = override.TOTPGetSecret
}
if override.TOTPRegenerateBackup != "" {
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
}
if override.TOTPValidateBackupCode != "" {
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
}
if override.PasskeyStoreCredential != "" {
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
}
if override.PasskeyGetCredsByUsername != "" {
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
}
if override.PasskeyGetCredential != "" {
merged.PasskeyGetCredential = override.PasskeyGetCredential
}
if override.PasskeyUpdateCounter != "" {
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
}
if override.PasskeyGetUserCredentials != "" {
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
}
if override.PasskeyDeleteCredential != "" {
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
}
if override.PasskeyUpdateName != "" {
merged.PasskeyUpdateName = override.PasskeyUpdateName
}
if override.PasskeyLogin != "" {
merged.PasskeyLogin = override.PasskeyLogin
}
if override.OAuthGetOrCreateUser != "" {
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
}
if override.OAuthCreateSession != "" {
merged.OAuthCreateSession = override.OAuthCreateSession
}
if override.OAuthGetRefreshToken != "" {
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
}
if override.OAuthUpdateRefreshToken != "" {
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
}
if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser
}
if override.OAuthRegisterClient != "" {
merged.OAuthRegisterClient = override.OAuthRegisterClient
}
if override.OAuthGetClient != "" {
merged.OAuthGetClient = override.OAuthGetClient
}
if override.OAuthSaveCode != "" {
merged.OAuthSaveCode = override.OAuthSaveCode
}
if override.OAuthExchangeCode != "" {
merged.OAuthExchangeCode = override.OAuthExchangeCode
}
if override.OAuthIntrospect != "" {
merged.OAuthIntrospect = override.OAuthIntrospect
}
if override.OAuthRevoke != "" {
merged.OAuthRevoke = override.OAuthRevoke
}
return &merged
}
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
// Returns an error if any field contains invalid characters.
func ValidateSQLNames(names *SQLNames) error {
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
val := field.String()
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
}
}
return nil
}
// resolveSQLNames merges an optional override with defaults.
// Used by constructors that accept variadic *SQLNames.
func resolveSQLNames(override ...*SQLNames) *SQLNames {
if len(override) > 0 && override[0] != nil {
return MergeSQLNames(DefaultSQLNames(), override[0])
}
return DefaultSQLNames()
}

View File

@@ -0,0 +1,145 @@
package security
import (
"reflect"
"testing"
)
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
names := DefaultSQLNames()
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
if field.String() == "" {
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
}
}
}
func TestMergeSQLNames_PartialOverride(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{
Login: "custom_login",
TOTPEnable: "custom_totp_enable",
PasskeyLogin: "custom_passkey_login",
}
merged := MergeSQLNames(base, override)
if merged.Login != "custom_login" {
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
}
if merged.TOTPEnable != "custom_totp_enable" {
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
}
if merged.PasskeyLogin != "custom_passkey_login" {
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
}
// Non-overridden fields should retain defaults
if merged.Logout != "resolvespec_logout" {
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
}
if merged.Session != "resolvespec_session" {
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
}
}
func TestMergeSQLNames_NilOverride(t *testing.T) {
base := DefaultSQLNames()
merged := MergeSQLNames(base, nil)
// Should be a copy, not the same pointer
if merged == base {
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
}
// All values should match
v1 := reflect.ValueOf(base).Elem()
v2 := reflect.ValueOf(merged).Elem()
typ := v1.Type()
for i := 0; i < v1.NumField(); i++ {
f1 := v1.Field(i)
f2 := v2.Field(i)
if f1.Kind() != reflect.String {
continue
}
if f1.String() != f2.String() {
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
}
}
}
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
base := DefaultSQLNames()
originalLogin := base.Login
override := &SQLNames{Login: "custom_login"}
_ = MergeSQLNames(base, override)
if base.Login != originalLogin {
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
}
}
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{}
v := reflect.ValueOf(override).Elem()
for i := 0; i < v.NumField(); i++ {
if v.Field(i).Kind() == reflect.String {
v.Field(i).SetString("custom_sentinel")
}
}
merged := MergeSQLNames(base, override)
mv := reflect.ValueOf(merged).Elem()
typ := mv.Type()
for i := 0; i < mv.NumField(); i++ {
if mv.Field(i).Kind() != reflect.String {
continue
}
if mv.Field(i).String() != "custom_sentinel" {
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
}
}
}
func TestValidateSQLNames_Valid(t *testing.T) {
names := DefaultSQLNames()
if err := ValidateSQLNames(names); err != nil {
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
}
}
func TestValidateSQLNames_Invalid(t *testing.T) {
names := DefaultSQLNames()
names.Login = "resolvespec_login; DROP TABLE users; --"
err := ValidateSQLNames(names)
if err == nil {
t.Error("ValidateSQLNames should reject names with invalid characters")
}
}
func TestResolveSQLNames_NoOverride(t *testing.T) {
names := resolveSQLNames()
if names.Login != "resolvespec_login" {
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
}
}
func TestResolveSQLNames_WithOverride(t *testing.T) {
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
if names.Login != "custom_login" {
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
}
if names.Logout != "resolvespec_logout" {
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
}
}

View File

@@ -9,23 +9,23 @@ import (
)
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// See totp_database_schema.sql for procedure definitions
type DatabaseTwoFactorProvider struct {
db *sql.DB
totpGen *TOTPGenerator
db *sql.DB
totpGen *TOTPGenerator
sqlNames *SQLNames
}
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
if config == nil {
config = DefaultTwoFactorConfig()
}
return &DatabaseTwoFactorProvider{
db: db,
totpGen: NewTOTPGenerator(config),
db: db,
totpGen: NewTOTPGenerator(config),
sqlNames: resolveSQLNames(names...),
}
}
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("enable 2FA query failed: %w", err)
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("disable 2FA query failed: %w", err)
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
var errorMsg sql.NullString
var enabled bool
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
if err != nil {
return false, fmt.Errorf("get 2FA status query failed: %w", err)
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var errorMsg sql.NullString
var secret sql.NullString
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
if err != nil {
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil {
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
var errorMsg sql.NullString
var valid bool
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
if err != nil {
return false, fmt.Errorf("validate backup code query failed: %w", err)

View File

@@ -3,6 +3,7 @@ package server_test
import (
"context"
"fmt"
"log"
"net/http"
"time"
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
GZIP: true, // Enable GZIP compression
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Start all servers
if err := mgr.StartAll(); err != nil {
panic(err)
log.Fatal(err)
}
// Server is now running...
// When done, stop gracefully
if err := mgr.StopAll(); err != nil {
panic(err)
log.Fatal(err)
}
}
@@ -61,7 +62,7 @@ func ExampleManager_https() {
SSLKey: "/path/to/key.pem",
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Option 2: Self-signed certificate (for development)
@@ -73,27 +74,27 @@ func ExampleManager_https() {
SelfSignedSSL: true,
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Option 3: Let's Encrypt / AutoTLS (for production)
_, err = mgr.Add(server.Config{
Name: "https-server-letsencrypt",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com",
Name: "https-server-letsencrypt",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache",
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Start all servers
if err := mgr.StartAll(); err != nil {
panic(err)
log.Fatal(err)
}
// Cleanup
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
IdleTimeout: 120 * time.Second,
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
Handler: mux,
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Add health and readiness endpoints
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
// Start the server
if err := mgr.StartAll(); err != nil {
panic(err)
log.Fatal(err)
}
// Health check returns:
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
GZIP: true,
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Admin API server (different port)
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
Handler: adminHandler,
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Metrics server (internal only)
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
Handler: metricsHandler,
})
if err != nil {
panic(err)
log.Fatal(err)
}
// Start all servers at once
if err := mgr.StartAll(); err != nil {
panic(err)
log.Fatal(err)
}
// Get specific server instance
publicInstance, err := mgr.Get("public-api")
if err != nil {
panic(err)
log.Fatal(err)
}
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
// Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil {
panic(err)
log.Fatal(err)
}
}
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
Handler: handler,
})
if err != nil {
panic(err)
log.Fatal(err)
}
if err := mgr.StartAll(); err != nil {
panic(err)
log.Fatal(err)
}
// Check server status