Compare commits

...

9 Commits

Author SHA1 Message Date
Hein
aa362c77da fix(cursor): trim parentheses from sort column names 2026-03-27 15:07:10 +02:00
Hein
1641eaf278 feat(resolvemcp): enhance handler with configuration support
* Introduce Config struct for BaseURL and BasePath settings
* Update handler creation functions to accept configuration
* Modify SSEServer to use dynamic base URL detection
* Adjust route setup functions to utilize BasePath from config
2026-03-27 13:56:03 +02:00
Hein
200a03c225 feat(resolvemcp): add SSE server and bunrouter setup functions
* Introduce SSEServer method for creating an SSE server bound to the handler.
* Add SetupBunRouterRoutes function to mount MCP HTTP/SSE endpoints on bunrouter.
* Update README with usage examples for new features.
2026-03-27 13:28:03 +02:00
Hein
7ef9cf39d3 style(tools): simplify string formatting in descriptions 2026-03-27 13:10:50 +02:00
Hein
7f6410f665 feat(resolvemcp): add support for join-column sorting in cursor pagination
* Enhance getCursorFilter to accept join clauses for sorting
* Update resolveColumn to handle joined columns
* Modify tests to validate new join functionality
2026-03-27 13:10:42 +02:00
Hein
835bbb0727 style(hooks): reorder fields in HookContext for consistency 2026-03-27 12:57:30 +02:00
Hein
047a1cc187 feat(resolvemcp): add hook system for model operations
* Implement hooks for CRUD operations: before/after handle, read, create, update, delete.
* Introduce HookContext and HookRegistry for managing hooks.
* Allow registration and execution of multiple hooks per operation.

feat(resolvemcp): implement MCP tools for CRUD operations
* Register tools for reading, creating, updating, and deleting records.
* Define tool arguments and handle requests with appropriate responses.
* Support for resource registration with metadata.

fix(restheadspec): enhance cursor handling for joins
* Improve cursor filter generation to support lateral joins.
* Update join alias extraction to handle lateral joins correctly.
* Ensure cursor filters do not contain empty comparisons.

test(restheadspec): add tests for cursor filters and join alias extraction
* Create tests for lateral join scenarios in cursor filter generation.
* Validate join alias extraction for various join types, including lateral joins.
2026-03-27 12:57:08 +02:00
Hein
7a498edab7 fix(headers): enhance relation name resolution logic
* Allow resolution for both regular headers and X-Files.
* Introduce join-key-aware resolution for disambiguation.
* Add new function to handle multiple fields pointing to the same type.
2026-03-25 12:09:03 +02:00
Hein
f10bb0827e fix(sql_helpers): ensure case-insensitive matching for allowed prefixes 2026-03-25 10:57:42 +02:00
19 changed files with 2631 additions and 52 deletions

View File

@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
3. **FuncSpec** - Header-based API to map and call API's to sql functions 3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations 4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications 5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering. All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
* [Quick Start](#quick-start) * [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api) * [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api) * [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
* [Architecture](#architecture) * [Architecture](#architecture)
* [API Structure](#api-structure) * [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api) * [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
* **🆕 Backward Compatible**: Existing code works without changes * **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing * **🆕 Better Testing**: Mockable interfaces for easy unit testing
### ResolveMCP (v3.2+)
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
### RestHeadSpec (v2.1+) ### RestHeadSpec (v2.1+)
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body * **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
### ResolveMCP (MCP Server)
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
// Create handler
handler := resolvemcp.NewHandlerWithGORM(db)
// Register models — must be done BEFORE Build()
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "posts", &Post{})
// Finalize: registers MCP tools and resources
handler.Build()
// Mount SSE transport on your existing router
router := mux.NewRouter()
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
// MCP clients connect to:
// SSE stream: GET http://localhost:8080/mcp/sse
// Messages: POST http://localhost:8080/mcp/message
//
// Auto-registered tools per model:
// read_public_users — filter, sort, paginate, preload
// create_public_users — insert a new record
// update_public_users — update a record by ID
// delete_public_users — delete a record by ID
```
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
## Architecture ## Architecture
### Two Complementary APIs ### Two Complementary APIs
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
#### ResolveMCP - MCP Server
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
**Key Features**:
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
- Same Before/After lifecycle hooks as ResolveSpec
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
#### FuncSpec - Function-Based SQL API #### FuncSpec - Function-Based SQL API
Execute SQL functions and queries through a simple HTTP API with header-based parameters. Execute SQL functions and queries through a simple HTTP API with header-based parameters.
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New ## What's New
### v3.1 (Latest - February 2026) ### v3.2 (Latest - March 2026)
**ResolveMCP - Model Context Protocol Server (🆕)**:
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
### v3.1 (February 2026)
**SQLite Schema Translation (🆕)**: **SQLite Schema Translation (🆕)**:

5
go.mod
View File

@@ -40,6 +40,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0 go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.1 go.uber.org/zap v1.27.1
golang.org/x/crypto v0.46.0 golang.org/x/crypto v0.46.0
golang.org/x/oauth2 v0.34.0
golang.org/x/time v0.14.0 golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0 gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0 gorm.io/driver/sqlite v1.6.0
@@ -78,6 +79,7 @@ require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v1.0.0 // indirect github.com/golang/snappy v1.0.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@@ -86,6 +88,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect github.com/magiconair/properties v1.8.10 // indirect
github.com/mark3labs/mcp-go v0.46.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect github.com/moby/go-archive v0.1.0 // indirect
@@ -131,6 +134,7 @@ require (
github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect
@@ -143,7 +147,6 @@ require (
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
golang.org/x/mod v0.31.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.32.0 // indirect

6
go.sum
View File

@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=

View File

@@ -168,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
} }
// Build a set of allowed table prefixes (main table + preloaded relations) // Build a set of allowed table prefixes (main table + preloaded relations)
// Keys are stored lowercase for case-insensitive matching
allowedPrefixes := make(map[string]bool) allowedPrefixes := make(map[string]bool)
if tableName != "" { if tableName != "" {
allowedPrefixes[tableName] = true allowedPrefixes[strings.ToLower(tableName)] = true
} }
// Add preload relation names as allowed prefixes // Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil { if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload { for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" { if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
} }
} }
@@ -185,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
// Add join aliases as allowed prefixes // Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases { for _, alias := range options[0].JoinAliases {
if alias != "" { if alias != "" {
allowedPrefixes[alias] = true allowedPrefixes[strings.ToLower(alias)] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias) logger.Debug("Added join alias '%s' as allowed table prefix", alias)
} }
} }
@@ -217,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
currentPrefix, columnName := extractTableAndColumn(condToCheck) currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" { if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation) // Check if the prefix is allowed (main table or preload relation) - case-insensitive
if !allowedPrefixes[currentPrefix] { if !allowedPrefixes[strings.ToLower(currentPrefix)] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table // Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) { if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name // Replace the incorrect prefix with the correct main table name

407
pkg/resolvemcp/README.md Normal file
View File

@@ -0,0 +1,407 @@
# resolvemcp
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
## Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
"github.com/gorilla/mux"
)
// 1. Create a handler
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
BaseURL: "http://localhost:8080",
})
// 2. Register models
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "orders", &Order{})
// 3. Mount routes
r := mux.NewRouter()
resolvemcp.SetupMuxRoutes(r, handler)
```
---
## Config
```go
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
// If empty, it is detected from each incoming request using the Host header and
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
BaseURL string
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
// Required.
BasePath string
}
```
## Handler Creation
| Function | Description |
|---|---|
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
---
## Registering Models
```go
handler.RegisterModel(schema, entity string, model interface{}) error
```
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
- `entity` — table/entity name (e.g. `"users"`).
- `model` — a pointer to a struct (e.g. `&User{}`).
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
---
## HTTP / SSE Transport
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
`Config.BasePath` is required and used for all route registration.
`Config.BaseURL` is optional — when empty it is detected from each request.
### Gorilla Mux
```go
resolvemcp.SetupMuxRoutes(r, handler)
```
Registers:
| Route | Method | Description |
|---|---|---|
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
| `{BasePath}/*` | any | Full SSE server (convenience prefix) |
### bunrouter
```go
resolvemcp.SetupBunRouterRoutes(router, handler)
```
Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`.
### Gin (or any `http.Handler`-compatible framework)
Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
```go
sse := handler.SSEServer()
// Gin
engine.Any("/mcp/*path", gin.WrapH(sse))
// net/http
http.Handle("/mcp/", sse)
// Echo
e.Any("/mcp/*", echo.WrapHandler(sse))
```
### Authentication
Add middleware before the MCP routes. The handler itself has no auth layer.
---
## MCP Tools
### Tool Naming
```
{operation}_{schema}_{entity} // e.g. read_public_users
{operation}_{entity} // e.g. read_users (when schema is empty)
```
Operations: `read`, `create`, `update`, `delete`.
### Read Tool — `read_{schema}_{entity}`
Fetch one or many records.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key value. Omit to return multiple records. |
| `limit` | number | Max records per page (recommended: 10100). |
| `offset` | number | Records to skip (offset-based pagination). |
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
| `columns` | array | Column names to include. Omit for all columns. |
| `omit_columns` | array | Column names to exclude. |
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
**Response:**
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 100,
"count": 10,
"limit": 10,
"offset": 0
}
}
```
### Create Tool — `create_{schema}_{entity}`
Insert one or more records.
| Argument | Type | Description |
|---|---|---|
| `data` | object \| array | Single object or array of objects to insert. |
Array input runs inside a single transaction — all succeed or all fail.
**Response:**
```json
{ "success": true, "data": { ... } }
```
### Update Tool — `update_{schema}_{entity}`
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key of the record. Can also be included inside `data`. |
| `data` | object (required) | Fields to update. |
**Response:**
```json
{ "success": true, "data": { ...merged record... } }
```
### Delete Tool — `delete_{schema}_{entity}`
Delete a record by primary key. **Irreversible.**
| Argument | Type | Description |
|---|---|---|
| `id` | string (required) | Primary key of the record to delete. |
**Response:**
```json
{ "success": true, "data": { ...deleted record... } }
```
### Resource — `{schema}.{entity}`
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
---
## Filtering
Pass an array of filter objects to the `filters` argument:
```json
[
{ "column": "status", "operator": "=", "value": "active" },
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
]
```
### Supported Operators
| Operator | Aliases | Description |
|---|---|---|
| `=` | `eq` | Equal |
| `!=` | `neq`, `<>` | Not equal |
| `>` | `gt` | Greater than |
| `>=` | `gte` | Greater than or equal |
| `<` | `lt` | Less than |
| `<=` | `lte` | Less than or equal |
| `like` | | SQL LIKE (case-sensitive) |
| `ilike` | | SQL ILIKE (case-insensitive) |
| `in` | | Value in list |
| `is_null` | | Column IS NULL |
| `is_not_null` | | Column IS NOT NULL |
### Logic Operators
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
---
## Sorting
```json
[
{ "column": "created_at", "direction": "desc" },
{ "column": "name", "direction": "asc" }
]
```
---
## Pagination
### Offset-Based
```json
{ "limit": 20, "offset": 40 }
```
### Cursor-Based
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
```json
// Next page: pass the PK of the last record on the current page
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
// Previous page: pass the PK of the first record on the current page
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
```
---
## Preloading Relations
```json
[
{ "relation": "Profile" },
{ "relation": "Orders" }
]
```
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
---
## Hook System
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
### Hook Types
| Constant | Fires |
|---|---|
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
| `BeforeRead` / `AfterRead` | Around read queries |
| `BeforeCreate` / `AfterCreate` | Around insert |
| `BeforeUpdate` / `AfterUpdate` | Around update |
| `BeforeDelete` / `AfterDelete` | Around delete |
### Registering Hooks
```go
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
// Inject a timestamp before insert
if data, ok := ctx.Data.(map[string]interface{}); ok {
data["created_at"] = time.Now()
}
return nil
})
// Register the same hook for multiple events
handler.Hooks().RegisterMultiple(
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
auditHook,
)
```
### HookContext Fields
| Field | Type | Description |
|---|---|---|
| `Context` | `context.Context` | Request context |
| `Handler` | `*Handler` | The resolvemcp handler |
| `Schema` | `string` | Database schema name |
| `Entity` | `string` | Entity/table name |
| `Model` | `interface{}` | Registered model instance |
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
| `ID` | `string` | Primary key from request (read/update/delete) |
| `Data` | `interface{}` | Input data (create/update — modifiable) |
| `Result` | `interface{}` | Output data (set by After hooks) |
| `Error` | `error` | Operation error, if any |
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
| `Tx` | `common.Database` | Database/transaction handle |
| `Abort` | `bool` | Set to `true` to abort the operation |
| `AbortMessage` | `string` | Error message returned when aborting |
| `AbortCode` | `int` | Optional status code for the abort |
### Aborting an Operation
```go
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "deletion is disabled"
return nil
})
```
### Managing Hooks
```go
registry := handler.Hooks()
registry.HasHooks(resolvemcp.BeforeCreate) // bool
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
registry.ClearAll() // remove all hooks
```
---
## Context Helpers
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
```go
schema := resolvemcp.GetSchema(ctx)
entity := resolvemcp.GetEntity(ctx)
tableName := resolvemcp.GetTableName(ctx)
model := resolvemcp.GetModel(ctx)
modelPtr := resolvemcp.GetModelPtr(ctx)
```
You can also set values manually (e.g. in middleware):
```go
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
```
---
## Adding Custom MCP Tools
Access the underlying `*server.MCPServer` to register additional tools:
```go
mcpServer := handler.MCPServer()
mcpServer.AddTool(myTool, myHandler)
```
---
## Table Name Resolution
The handler resolves table names in priority order:
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)

71
pkg/resolvemcp/context.go Normal file
View File

@@ -0,0 +1,71 @@
package resolvemcp
import "context"
type contextKey string
const (
contextKeySchema contextKey = "schema"
contextKeyEntity contextKey = "entity"
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
)
func WithSchema(ctx context.Context, schema string) context.Context {
return context.WithValue(ctx, contextKeySchema, schema)
}
func GetSchema(ctx context.Context) string {
if v := ctx.Value(contextKeySchema); v != nil {
return v.(string)
}
return ""
}
func WithEntity(ctx context.Context, entity string) context.Context {
return context.WithValue(ctx, contextKeyEntity, entity)
}
func GetEntity(ctx context.Context) string {
if v := ctx.Value(contextKeyEntity); v != nil {
return v.(string)
}
return ""
}
func WithTableName(ctx context.Context, tableName string) context.Context {
return context.WithValue(ctx, contextKeyTableName, tableName)
}
func GetTableName(ctx context.Context) string {
if v := ctx.Value(contextKeyTableName); v != nil {
return v.(string)
}
return ""
}
func WithModel(ctx context.Context, model interface{}) context.Context {
return context.WithValue(ctx, contextKeyModel, model)
}
func GetModel(ctx context.Context) interface{} {
return ctx.Value(contextKeyModel)
}
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
}
func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
return ctx
}

161
pkg/resolvemcp/cursor.go Normal file
View File

@@ -0,0 +1,161 @@
package resolvemcp
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
type cursorDirection int
const (
cursorForward cursorDirection = 1
cursorBackward cursorDirection = -1
)
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
func getCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
cursorID, direction := getActiveCursor(options)
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
sortItems := options.Sort
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
var whereClauses []string
joinSQL := ""
reverse := direction < 0
for _, s := range sortItems {
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
desc := strings.EqualFold(s.Direction, "desc")
if reverse {
desc = !desc
}
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
orSQL := buildCursorPriorityChain(whereClauses)
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
)
return query, nil
}
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
if options.CursorForward != "" {
return options.CursorForward, cursorForward
}
if options.CursorBackward != "" {
return options.CursorBackward, cursorBackward
}
return "", 0
}
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if prefix != "" && prefix != tableName {
return "", "", true, nil
}
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
func buildCursorPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

706
pkg/resolvemcp/handler.go Normal file
View File

@@ -0,0 +1,706 @@
package resolvemcp
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"github.com/mark3labs/mcp-go/server"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Handler exposes registered database models as MCP tools and resources.
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
mcpServer *server.MCPServer
config Config
name string
version string
}
// NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
return &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
config: cfg,
name: "resolvemcp",
version: "1.0.0",
}
}
// Hooks returns the hook registry.
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// GetDatabase returns the underlying database.
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
func (h *Handler) MCPServer() *server.MCPServer {
return h.mcpServer
}
// SSEServer returns an http.Handler that serves MCP over SSE.
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
// detected automatically from each incoming request.
func (h *Handler) SSEServer() http.Handler {
if h.config.BaseURL != "" {
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
}
return &dynamicSSEHandler{h: h}
}
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer(
h.mcpServer,
server.WithBaseURL(baseURL),
server.WithBasePath(basePath),
)
}
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
type dynamicSSEHandler struct {
h *Handler
mu sync.Mutex
pool map[string]*server.SSEServer
}
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
baseURL := requestBaseURL(r)
d.mu.Lock()
if d.pool == nil {
d.pool = make(map[string]*server.SSEServer)
}
s, ok := d.pool[baseURL]
if !ok {
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
d.pool[baseURL] = s
}
d.mu.Unlock()
s.ServeHTTP(w, r)
}
// requestBaseURL builds the base URL from an incoming request.
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
func requestBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = proto
}
return scheme + "://" + r.Host
}
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
fullName := buildModelName(schema, entity)
if err := h.registry.RegisterModel(fullName, model); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string {
if schema == "" {
return entity
}
return fmt.Sprintf("%s.%s", schema, entity)
}
// getTableName returns the fully qualified table name for a model.
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
if idx := strings.LastIndex(tableName, "."); idx != -1 {
return tableName[:idx], tableName[idx+1:]
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), tableName
}
return defaultSchema, tableName
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), entity
}
return defaultSchema, entity
}
// executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err)
}
unwrapped, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, nil, fmt.Errorf("invalid model: %w", err)
}
model = unwrapped.Model
modelType := unwrapped.ModelType
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
validator := common.NewColumnValidator(model)
options = validator.FilterRequestOptions(options)
// BeforeHandle hook
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "read",
Options: options,
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, nil, err
}
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
query := h.db.NewSelect().Model(modelPtr)
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Column selection
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
options.Columns = reflection.GetSQLModelColumns(model)
}
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
for _, cu := range options.ComputedColumns {
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
}
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters
query = h.applyFilters(query, options.Filters)
// Custom operators
for _, customOp := range options.CustomOperators {
query = query.Where(customOp.SQL)
}
// Sorting
for _, sort := range options.Sort {
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Cursor pagination
if options.CursorForward != "" || options.CursorBackward != "" {
pkName := reflection.GetPrimaryKeyName(model)
modelColumns := reflection.GetModelColumns(model)
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// expandJoins is empty for resolvemcp — no custom SQL join support yet
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err)
}
if cursorFilter != "" {
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
sanitized = common.EnsureOuterParentheses(sanitized)
if sanitized != "" {
query = query.Where(sanitized)
}
}
}
// Count
total, err := query.Count(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err)
}
// Pagination
if options.Limit != nil && *options.Limit > 0 {
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
query = query.Offset(*options.Offset)
}
// BeforeRead hook
hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
return nil, nil, err
}
var data interface{}
if id != "" {
singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := query.Scan(ctx, singleResult); err != nil {
if err == sql.ErrNoRows {
return nil, nil, fmt.Errorf("record not found")
}
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = singleResult
} else {
if err := query.Scan(ctx, modelPtr); err != nil {
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = reflect.ValueOf(modelPtr).Elem().Interface()
}
limit := 0
offset := 0
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Count is the number of records in this page, not the total.
var pageCount int64
if id != "" {
pageCount = 1
} else {
pageCount = int64(reflect.ValueOf(data).Len())
}
metadata := &common.Metadata{
Total: int64(total),
Filtered: int64(total),
Count: pageCount,
Limit: limit,
Offset: offset,
}
// AfterRead hook
hookCtx.Result = data
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
return nil, nil, err
}
return data, metadata, nil
}
// executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "create",
Data: data,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
return nil, err
}
// Use potentially modified data
data = hookCtx.Data
switch v := data.(type) {
case map[string]interface{}:
query := h.db.NewInsert().Table(tableName)
for key, value := range v {
query = query.Value(key, value)
}
if _, err := query.Exec(ctx); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
hookCtx.Result = v
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return v, nil
case []interface{}:
results := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
itemMap, ok := item.(map[string]interface{})
if !ok {
return fmt.Errorf("each item must be an object")
}
q := tx.NewInsert().Table(tableName)
for key, value := range itemMap {
q = q.Value(key, value)
}
if _, err := q.Exec(ctx); err != nil {
return err
}
results = append(results, item)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("batch create error: %w", err)
}
hookCtx.Result = results
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return results, nil
default:
return nil, fmt.Errorf("data must be an object or array of objects")
}
}
// executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
updates, ok := data.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("data must be an object")
}
if id == "" {
if idVal, exists := updates["id"]; exists {
id = fmt.Sprintf("%v", idVal)
}
}
if id == "" {
return nil, fmt.Errorf("update requires an ID")
}
pkName := reflection.GetPrimaryKeyName(model)
var updateResult interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Read existing record
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
existingRecord := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no records found to update")
}
return fmt.Errorf("error fetching existing record: %w", err)
}
// Convert to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
return fmt.Errorf("error marshaling existing record: %w", err)
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
return fmt.Errorf("error unmarshaling existing record: %w", err)
}
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "update",
ID: id,
Data: updates,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return err
}
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
updates = modifiedData
}
// Merge non-nil, non-empty values
for key, newValue := range updates {
if newValue == nil {
continue
}
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
}
existingMap[key] = newValue
}
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
res, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("error updating record: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("no records found to update")
}
updateResult = existingMap
hookCtx.Result = updateResult
return h.hooks.Execute(AfterUpdate, hookCtx)
})
if err != nil {
return nil, err
}
return updateResult, nil
}
// executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
if id == "" {
return nil, fmt.Errorf("delete requires an ID")
}
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
pkName := reflection.GetPrimaryKeyName(model)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "delete",
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
return nil, err
}
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
var recordToDelete interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
record := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(record).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("record not found")
}
return fmt.Errorf("error fetching record: %w", err)
}
res, err := tx.NewDelete().Table(tableName).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete error: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("record not found or already deleted")
}
recordToDelete = record
hookCtx.Tx = tx
hookCtx.Result = record
return h.hooks.Execute(AfterDelete, hookCtx)
})
if err != nil {
return nil, err
}
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
return recordToDelete, nil
}
// applyFilters applies all filters with OR grouping logic.
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
}
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
switch filter.Operator {
case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
case "neq", "!=", "<>":
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
case "gt", ">":
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
case "gte", ">=":
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
case "lt", "<":
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
case "lte", "<=":
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
case "in":
condition, args := common.BuildInCondition(filter.Column, filter.Value)
return condition, args
case "is_null":
return fmt.Sprintf("%s IS NULL", filter.Column), nil
case "is_not_null":
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
}
return "", nil
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for _, preload := range preloads {
if preload.Relation == "" {
continue
}
query = query.PreloadRelation(preload.Relation)
}
return query, nil
}

113
pkg/resolvemcp/hooks.go Normal file
View File

@@ -0,0 +1,113 @@
package resolvemcp
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
BeforeHandle HookType = "before_handle"
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Operation string
ID string
Data interface{}
Result interface{}
Error error
Query common.SelectQuery
Abort bool
AbortMessage string
AbortCode int
Tx common.Database
}
// HookFunc is the signature for hook functions
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
if ctx.Abort {
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
}
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
}
func (r *HookRegistry) HasHooks(hookType HookType) bool {
hooks, exists := r.hooks[hookType]
return exists && len(hooks) > 0
}

View File

@@ -0,0 +1,100 @@
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
// and resources over HTTP/SSE transport.
//
// It mirrors the resolvespec package patterns:
// - Same model registration API
// - Same filter, sort, cursor pagination, preload options
// - Same lifecycle hook system
//
// Usage:
//
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
// handler.RegisterModel("public", "users", &User{})
//
// r := mux.NewRouter()
// resolvemcp.SetupMuxRoutes(r, handler)
package resolvemcp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
bunrouter "github.com/uptrace/bunrouter"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// Config holds configuration for the resolvemcp handler.
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
BaseURL string
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
// If empty, the path is detected from each incoming request automatically.
BasePath string
}
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
}
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
//
// To protect these routes with authentication, wrap the mux router or apply middleware
// before calling SetupMuxRoutes.
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
// Convenience: also expose the full SSE server at basePath for clients that
// use ServeHTTP directly (e.g. net/http default mux).
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
// using the base path from Config.BasePath.
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint
// - POST {basePath}/message — JSON-RPC message endpoint
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
}
// NewSSEServer returns an http.Handler that serves MCP over SSE.
// If Config.BasePath is set it is used directly; otherwise the base path is
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
//
// h := resolvemcp.NewSSEServer(handler)
// http.Handle("/api/mcp/", h)
func NewSSEServer(handler *Handler) http.Handler {
return handler.SSEServer()
}

692
pkg/resolvemcp/tools.go Normal file
View File

@@ -0,0 +1,692 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// toolName builds the MCP tool name for a given operation and model.
func toolName(operation, schema, entity string) string {
if schema == "" {
return fmt.Sprintf("%s_%s", operation, entity)
}
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
}
// registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
info := buildModelInfo(schema, entity, model)
registerReadTool(h, schema, entity, info)
registerCreateTool(h, schema, entity, info)
registerUpdateTool(h, schema, entity, info)
registerDeleteTool(h, schema, entity, info)
registerModelResource(h, schema, entity, info)
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
}
// --------------------------------------------------------------------------
// Model introspection
// --------------------------------------------------------------------------
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
type modelInfo struct {
fullName string // e.g. "public.users"
pkName string // e.g. "id"
columns []columnInfo
relationNames []string
schemaDoc string // formatted multi-line schema listing
}
type columnInfo struct {
jsonName string
sqlName string
goType string
sqlType string
isPrimary bool
isUnique bool
isFK bool
nullable bool
}
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
info := modelInfo{
fullName: buildModelName(schema, entity),
pkName: reflection.GetPrimaryKeyName(model),
}
// Unwrap to base struct type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return info
}
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
for _, d := range details {
// Derive the JSON name from the struct field
jsonName := fieldJSONName(modelType, d.Name)
if jsonName == "" || jsonName == "-" {
continue
}
// Skip relation fields (slice or user-defined struct that isn't time.Time).
fieldType, found := modelType.FieldByName(d.Name)
if found {
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
if ft.Kind() == reflect.Slice || isUserStruct {
info.relationNames = append(info.relationNames, jsonName)
continue
}
}
sqlName := d.SQLName
if sqlName == "" {
sqlName = jsonName
}
// Derive Go type name, unwrapping pointer if needed.
goType := d.DataType
if goType == "" && found {
ft := fieldType.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
goType = ft.Name()
}
// isPrimary: use both the GORM-tag detection and a name comparison against
// the known primary key (handles camelCase "primaryKey" tags correctly).
isPrimary := d.SQLKey == "primary_key" ||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
ci := columnInfo{
jsonName: jsonName,
sqlName: sqlName,
goType: goType,
sqlType: d.SQLDataType,
isPrimary: isPrimary,
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
isFK: d.SQLKey == "foreign_key",
nullable: d.Nullable,
}
info.columns = append(info.columns, ci)
}
info.schemaDoc = buildSchemaDoc(info)
return info
}
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
func fieldJSONName(modelType reflect.Type, fieldName string) string {
field, ok := modelType.FieldByName(fieldName)
if !ok {
return fieldName
}
tag := field.Tag.Get("json")
if tag == "" {
return fieldName
}
parts := strings.SplitN(tag, ",", 2)
if parts[0] == "" {
return fieldName
}
return parts[0]
}
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
func buildSchemaDoc(info modelInfo) string {
if len(info.columns) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("Columns:\n")
for _, c := range info.columns {
line := fmt.Sprintf(" • %s", c.jsonName)
typeDesc := c.goType
if c.sqlType != "" {
typeDesc = c.sqlType
}
if typeDesc != "" {
line += fmt.Sprintf(" (%s)", typeDesc)
}
var flags []string
if c.isPrimary {
flags = append(flags, "primary key")
}
if c.isUnique {
flags = append(flags, "unique")
}
if c.isFK {
flags = append(flags, "foreign key")
}
if !c.nullable && !c.isPrimary {
flags = append(flags, "not null")
} else if c.nullable {
flags = append(flags, "nullable")
}
if len(flags) > 0 {
line += " — " + strings.Join(flags, ", ")
}
sb.WriteString(line + "\n")
}
if len(info.relationNames) > 0 {
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
}
return sb.String()
}
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
func columnNameList(cols []columnInfo) string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.jsonName
}
return strings.Join(names, ", ")
}
// writableColumnNames returns JSON names for all non-primary-key columns.
func writableColumnNames(cols []columnInfo) []string {
var names []string
for _, c := range cols {
if !c.isPrimary {
names = append(names, c.jsonName)
}
}
return names
}
// --------------------------------------------------------------------------
// Read tool
// --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("read", schema, entity)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
}
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
descParts = append(descParts,
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
)
if len(info.relationNames) > 0 {
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
}
description := strings.Join(descParts, "\n\n")
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
if len(info.columns) > 0 {
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
if len(info.columns) > 0 {
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
),
mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return per page. Recommended: 10100."),
),
mcp.WithNumber("offset",
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
),
mcp.WithString("cursor_forward",
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithString("cursor_backward",
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithArray("columns",
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("omit_columns",
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("filters",
mcp.Description(filterDesc),
),
mcp.WithArray("sort",
mcp.Description(sortDesc),
),
mcp.WithArray("preloads",
mcp.Description(buildPreloadDesc(info)),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
options := parseRequestOptions(args)
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": data,
"metadata": metadata,
})
})
}
func buildPreloadDesc(info modelInfo) string {
if len(info.relationNames) == 0 {
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
}
return fmt.Sprintf(
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
strings.Join(info.relationNames, ", "),
)
}
// --------------------------------------------------------------------------
// Create tool
// --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("create", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
}
descParts = append(descParts,
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
dataDesc := "Record fields to create."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
}
dataDesc += " Pass a single object or an array of objects."
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
result, err := h.executeCreate(ctx, schema, entity, data)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Update tool
// --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("update", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
}
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
}
descParts = append(descParts,
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(idDesc),
),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
dataMap, ok := data.(map[string]interface{})
if !ok {
return mcp.NewToolResultError("data must be an object"), nil
}
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Delete tool
// --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("delete", schema, entity)
descParts := []string{
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
}
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
description := strings.Join(descParts, " ")
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
result, err := h.executeDelete(ctx, schema, entity, id)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Resource registration
// --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
resourceURI := info.fullName
var resourceDesc strings.Builder
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
if info.pkName != "" {
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
}
if info.schemaDoc != "" {
resourceDesc.WriteString("\n\n")
resourceDesc.WriteString(info.schemaDoc)
}
resource := mcp.NewResource(
resourceURI,
entity,
mcp.WithResourceDescription(resourceDesc.String()),
mcp.WithMIMEType("application/json"),
)
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
limit := 100
options := common.RequestOptions{Limit: &limit}
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
if err != nil {
return nil, err
}
payload := map[string]interface{}{
"data": data,
"metadata": metadata,
}
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling resource: %w", err)
}
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "application/json",
Text: string(jsonBytes),
},
}, nil
})
}
// --------------------------------------------------------------------------
// Argument parsing helpers
// --------------------------------------------------------------------------
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{}
if v, ok := args["limit"]; ok {
switch n := v.(type) {
case float64:
limit := int(n)
options.Limit = &limit
case int:
options.Limit = &n
}
}
if v, ok := args["offset"]; ok {
switch n := v.(type) {
case float64:
offset := int(n)
options.Offset = &offset
case int:
options.Offset = &n
}
}
if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v
}
if v, ok := args["cursor_backward"].(string); ok {
options.CursorBackward = v
}
options.Columns = parseStringArray(args["columns"])
options.OmitColumns = parseStringArray(args["omit_columns"])
options.Filters = parseFilters(args["filters"])
options.Sort = parseSortOptions(args["sort"])
options.Preload = parsePreloadOptions(args["preloads"])
return options
}
func parseStringArray(raw interface{}) []string {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]string, 0, len(items))
for _, item := range items {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}
func parseFilters(raw interface{}) []common.FilterOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.FilterOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var f common.FilterOption
if err := json.Unmarshal(b, &f); err != nil {
continue
}
if f.Column == "" || f.Operator == "" {
continue
}
if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR"
} else {
f.LogicOperator = "AND"
}
result = append(result, f)
}
return result
}
func parseSortOptions(raw interface{}) []common.SortOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.SortOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var s common.SortOption
if err := json.Unmarshal(b, &s); err != nil {
continue
}
if s.Column == "" {
continue
}
result = append(result, s)
}
return result
}
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.PreloadOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var p common.PreloadOption
if err := json.Unmarshal(b, &p); err != nil {
continue
}
if p.Relation == "" {
continue
}
result = append(result, p)
}
return result
}
// marshalResult marshals a value to JSON and returns it as an MCP text result.
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
b, err := json.Marshal(v)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
}
return mcp.NewToolResultText(string(b)), nil
}

View File

@@ -24,6 +24,7 @@ const (
// - pkName: primary key column (e.g. "id") // - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip. // - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - options: the request options containing sort and cursor information // - options: the request options containing sort and cursor information
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
// //
// Returns SQL snippet to embed in WHERE clause. // Returns SQL snippet to embed in WHERE clause.
func GetCursorFilter( func GetCursorFilter(
@@ -31,6 +32,7 @@ func GetCursorFilter(
pkName string, pkName string,
modelColumns []string, modelColumns []string,
options common.RequestOptions, options common.RequestOptions,
expandJoins map[string]string,
) (string, error) { ) (string, error) {
// Separate schema prefix from bare table name // Separate schema prefix from bare table name
fullTableName := tableName fullTableName := tableName
@@ -58,18 +60,19 @@ func GetCursorFilter(
// 3. Prepare // 3. Prepare
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
var whereClauses []string var whereClauses []string
joinSQL := ""
reverse := direction < 0 reverse := direction < 0
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
// 4. Process each sort column // 4. Process each sort column
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
for _, s := range sortItems { for _, s := range sortItems {
col := strings.TrimSpace(s.Column) col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" { if col == "" {
continue continue
} }
// Parse: "created_at", "user.name", etc. // Parse: "created_at", "user.name", "fn.sortorder", etc.
parts := strings.Split(col, ".") parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1]) field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".") prefix := strings.Join(parts[:len(parts)-1], ".")
@@ -82,7 +85,7 @@ func GetCursorFilter(
} }
// Resolve column // Resolve column
cursorCol, targetCol, err := resolveColumn( cursorCol, targetCol, isJoin, err := resolveColumn(
field, prefix, tableName, modelColumns, field, prefix, tableName, modelColumns,
) )
if err != nil { if err != nil {
@@ -90,6 +93,22 @@ func GetCursorFilter(
continue continue
} }
// Handle joins
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
// Build inequality // Build inequality
op := "<" op := "<"
if desc { if desc {
@@ -113,10 +132,12 @@ func GetCursorFilter(
query := fmt.Sprintf(`EXISTS ( query := fmt.Sprintf(`EXISTS (
SELECT 1 SELECT 1
FROM %s cursor_select FROM %s cursor_select
%s
WHERE cursor_select.%s = %s WHERE cursor_select.%s = %s
AND (%s) AND (%s)
)`, )`,
fullTableName, fullTableName,
joinSQL,
pkName, pkName,
cursorID, cursorID,
orSQL, orSQL,
@@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
return "", 0 return "", 0
} }
// Helper: resolve column (main table only for now) // Helper: resolve column (main table or join)
func resolveColumn( func resolveColumn(
field, prefix, tableName string, field, prefix, tableName string,
modelColumns []string, modelColumns []string,
) (cursorCol, targetCol string, err error) { ) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field // JSON field
if strings.Contains(field, "->") { if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
// Main table column // Main table column
if modelColumns != nil { if modelColumns != nil {
for _, col := range modelColumns { for _, col := range modelColumns {
if strings.EqualFold(col, field) { if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
} }
} else { } else {
// No validation → allow all main-table fields // No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
// Joined column (not supported in resolvespec yet) // Joined column
if prefix != "" && prefix != tableName { if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field) return "", "", true, nil
} }
return "", "", fmt.Errorf("invalid column: %s", field) return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
} }
// ------------------------------------------------------------------------- // // ------------------------------------------------------------------------- //

View File

@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"} modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"} modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "created_at"} modelColumns := []string{"id", "title", "created_at"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options) _, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil { if err == nil {
t.Error("Expected error when no cursor is provided") t.Error("Expected error when no cursor is provided")
} }
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title"} modelColumns := []string{"id", "title"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options) _, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil { if err == nil {
t.Error("Expected error when no sort columns are defined") t.Error("Expected error when no sort columns are defined")
} }
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"} modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "name", "email"} modelColumns := []string{"id", "name", "email"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Logf("Generated cursor filter with schema: %s", filter) t.Logf("Generated cursor filter with schema: %s", filter)
} }
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
options := common.RequestOptions{
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
CursorForward: "8975",
}
tableName := "core.account"
pkName := "rid_account"
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestGetActiveCursor(t *testing.T) { func TestGetActiveCursor(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Joined column (not supported)", name: "Joined column (isJoin=true, no error)",
field: "name", field: "name",
prefix: "user", prefix: "user",
tableName: "posts", tableName: "posts",
modelColumns: []string{"id", "title"}, modelColumns: []string{"id", "title"},
wantErr: true, wantErr: false,
// cursorCol and targetCol are empty when isJoin=true; handled by caller
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns) cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
// For join columns, cursor/target are empty and isJoin=true
if isJoin {
if cursor != "" || target != "" {
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
}
return
}
if cursor != tt.wantCursor { if cursor != tt.wantCursor {
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor) t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
} }
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "created_at"} modelColumns := []string{"id", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }

View File

@@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
} }
// Get cursor filter SQL // Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options) cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
logger.Error("Error building cursor filter: %v", err) logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err) h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)

View File

@@ -64,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
// 4. Process each sort column // 4. Process each sort column
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
for _, s := range sortItems { for _, s := range sortItems {
col := strings.TrimSpace(s.Column) col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" { if col == "" {
continue continue
} }
@@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
} }
// Handle joins // Handle joins
if isJoin && expandJoins != nil { if isJoin {
if joinClause, ok := expandJoins[prefix]; ok { if expandJoins != nil {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix) if joinClause, ok := expandJoins[prefix]; ok {
joinSQL = jSQL jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
cursorCol = cRef + "." + field joinSQL = jSQL
targetCol = prefix + "." + field cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
} }
} }

View File

@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
} }
} }
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "fn.sortorder", Direction: "ASC"},
},
},
}
opts.CursorForward = "8975"
tableName := "core.account"
pkName := "rid_account"
// modelColumns does not contain "sortorder" - it's a lateral join computed column
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
// Should contain the rewritten lateral join inside the EXISTS subquery
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
// Should compare fn.sortorder values
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
// Should NOT contain empty comparison like "< "
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestBuildPriorityChain(t *testing.T) { func TestBuildPriorityChain(t *testing.T) {
clauses := []string{ clauses := []string{
"cursor_select.priority > posts.priority", "cursor_select.priority > posts.priority",

View File

@@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation using the generic database function // Extract model columns for validation using the generic database function
modelColumns := reflection.GetModelColumns(model) modelColumns := reflection.GetModelColumns(model)
// Build expand joins map (if needed in future) // Build expand joins map: custom SQL joins are available in cursor subquery
var expandJoins map[string]string expandJoins := make(map[string]string)
if len(options.Expand) > 0 { for _, joinClause := range options.CustomSQLJoin {
expandJoins = make(map[string]string) alias := extractJoinAlias(joinClause)
// TODO: Build actual JOIN SQL for each expand relation if alias != "" {
// For now, pass empty map as joins are handled via Preload expandJoins[alias] = joinClause
}
} }
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
// Default sort to primary key when none provided // Default sort to primary key when none provided
if len(options.Sort) == 0 { if len(options.Sort) == 0 {

View File

@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
} }
} }
// Resolve relation names (convert table names to field names) if model is provided // Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names // This runs for both regular headers and X-Files, because XFile prefixes don't always match model
if model != nil && !options.XFilesPresent { // field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
// disambiguate when multiple fields point to the same related type.
if model != nil {
h.resolveRelationNamesInOptions(&options, model) h.resolveRelationNamesInOptions(&options, model)
} }
@@ -550,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
// - "LEFT JOIN departments d ON ..." -> "d" // - "LEFT JOIN departments d ON ..." -> "d"
// - "INNER JOIN users AS u ON ..." -> "u" // - "INNER JOIN users AS u ON ..." -> "u"
// - "JOIN roles r ON ..." -> "r" // - "JOIN roles r ON ..." -> "r"
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
func extractJoinAlias(joinClause string) string { func extractJoinAlias(joinClause string) string {
// Pattern: JOIN table_name [AS] alias ON ...
// We need to extract the alias (word before ON)
upperJoin := strings.ToUpper(joinClause) upperJoin := strings.ToUpper(joinClause)
// Find the "JOIN" keyword position // Find the "JOIN" keyword position
@@ -562,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
return "" return ""
} }
// Find the "ON" keyword position // Lateral joins: alias is the word after the closing ) and before ON
if strings.Contains(upperJoin, "LATERAL") {
lastClose := strings.LastIndex(joinClause, ")")
if lastClose != -1 {
words := strings.Fields(joinClause[lastClose+1:])
// words should be like ["fn", "on", "true"] or ["on", "true"]
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
return words[0]
}
}
return ""
}
// Regular joins: find the "ON" keyword position (first occurrence)
onIdx := strings.Index(upperJoin, " ON ") onIdx := strings.Index(upperJoin, " ON ")
if onIdx == -1 { if onIdx == -1 {
return "" return ""
@@ -863,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Resolve each part of the path // Resolve each part of the path
currentModel := model currentModel := model
for _, part := range parts { for partIdx, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part) isLast := partIdx == len(parts)-1
var resolvedPart string
if isLast {
// For the final part, use join-key-aware resolution to disambiguate when
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
localKey := preload.RelatedKey
if localKey == "" {
localKey = preload.ForeignKey
}
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
} else {
resolvedPart = h.resolveRelationName(currentModel, part)
}
resolvedParts = append(resolvedParts, resolvedPart) resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level // Try to get the model type for the next level
@@ -980,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
return nameOrTable return nameOrTable
} }
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
if localKey == "" {
return h.resolveRelationName(model, nameOrTable)
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nameOrTable
}
// If it's already a direct field name, return as-is (no ambiguity).
for i := 0; i < modelType.NumField(); i++ {
if modelType.Field(i).Name == nameOrTable {
return nameOrTable
}
}
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
localKeyLower := strings.ToLower(localKey)
// Find all fields whose related type matches nameOrTable, then pick the one
// whose bun join tag local key matches localKey.
var fallbackField string
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil && targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType == nil || targetType.Kind() != reflect.Struct {
continue
}
normalizedTypeName := strings.ToLower(targetType.Name())
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
if normalizedTypeName != normalizedInput {
continue
}
// Type name matches; record as fallback.
if fallbackField == "" {
fallbackField = field.Name
}
// Check bun join tag: "join:localKey=foreignKey"
bunTag := field.Tag.Get("bun")
for _, tagPart := range strings.Split(bunTag, ",") {
tagPart = strings.TrimSpace(tagPart)
if !strings.HasPrefix(tagPart, "join:") {
continue
}
joinSpec := strings.TrimPrefix(tagPart, "join:")
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
joinCols := strings.Fields(joinSpec)
if len(joinCols) == 0 {
joinCols = []string{joinSpec}
}
for _, joinCol := range joinCols {
eqIdx := strings.Index(joinCol, "=")
if eqIdx < 0 {
continue
}
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
if joinLocalKey == localKeyLower {
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
return field.Name
}
}
}
}
if fallbackField != "" {
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
return fallbackField
}
return h.resolveRelationName(model, nameOrTable)
}
// addXFilesPreload converts an XFiles relation into a PreloadOption // addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children // and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) { func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {

View File

@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
joinClause: "LEFT JOIN departments", joinClause: "LEFT JOIN departments",
expected: "", expected: "",
}, },
{
name: "LATERAL join with alias",
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
expected: "fn",
},
{
name: "LATERAL join with multiline subquery containing inner ON",
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
expected: "fn",
},
} }
for _, tt := range tests { for _, tt := range tests {