mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ef9cf39d3 | ||
|
|
7f6410f665 | ||
|
|
835bbb0727 | ||
|
|
047a1cc187 |
71
README.md
71
README.md
@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
|
||||
|
||||
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
|
||||
* [Quick Start](#quick-start)
|
||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
|
||||
* [Architecture](#architecture)
|
||||
* [API Structure](#api-structure)
|
||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
|
||||
* **🆕 Backward Compatible**: Existing code works without changes
|
||||
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
||||
|
||||
### ResolveMCP (v3.2+)
|
||||
|
||||
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
|
||||
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
|
||||
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
|
||||
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
|
||||
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
|
||||
|
||||
### RestHeadSpec (v2.1+)
|
||||
|
||||
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
||||
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
### ResolveMCP (MCP Server)
|
||||
|
||||
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
|
||||
// Create handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db)
|
||||
|
||||
// Register models — must be done BEFORE Build()
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "posts", &Post{})
|
||||
|
||||
// Finalize: registers MCP tools and resources
|
||||
handler.Build()
|
||||
|
||||
// Mount SSE transport on your existing router
|
||||
router := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
|
||||
|
||||
// MCP clients connect to:
|
||||
// SSE stream: GET http://localhost:8080/mcp/sse
|
||||
// Messages: POST http://localhost:8080/mcp/message
|
||||
//
|
||||
// Auto-registered tools per model:
|
||||
// read_public_users — filter, sort, paginate, preload
|
||||
// create_public_users — insert a new record
|
||||
// update_public_users — update a record by ID
|
||||
// delete_public_users — delete a record by ID
|
||||
```
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two Complementary APIs
|
||||
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
#### ResolveMCP - MCP Server
|
||||
|
||||
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
|
||||
|
||||
**Key Features**:
|
||||
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
|
||||
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
|
||||
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
|
||||
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
|
||||
- Same Before/After lifecycle hooks as ResolveSpec
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
|
||||
|
||||
#### FuncSpec - Function-Based SQL API
|
||||
|
||||
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
||||
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
|
||||
## What's New
|
||||
|
||||
### v3.1 (Latest - February 2026)
|
||||
### v3.2 (Latest - March 2026)
|
||||
|
||||
**ResolveMCP - Model Context Protocol Server (🆕)**:
|
||||
|
||||
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
|
||||
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
|
||||
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
|
||||
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
|
||||
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
|
||||
|
||||
### v3.1 (February 2026)
|
||||
|
||||
**SQLite Schema Translation (🆕)**:
|
||||
|
||||
|
||||
5
go.mod
5
go.mod
@@ -40,6 +40,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
@@ -78,6 +79,7 @@ require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -86,6 +88,7 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mark3labs/mcp-go v0.46.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
@@ -131,6 +134,7 @@ require (
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.2.0 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
@@ -143,7 +147,6 @@ require (
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
|
||||
71
pkg/resolvemcp/context.go
Normal file
71
pkg/resolvemcp/context.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package resolvemcp
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeySchema contextKey = "schema"
|
||||
contextKeyEntity contextKey = "entity"
|
||||
contextKeyTableName contextKey = "tableName"
|
||||
contextKeyModel contextKey = "model"
|
||||
contextKeyModelPtr contextKey = "modelPtr"
|
||||
)
|
||||
|
||||
func WithSchema(ctx context.Context, schema string) context.Context {
|
||||
return context.WithValue(ctx, contextKeySchema, schema)
|
||||
}
|
||||
|
||||
func GetSchema(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeySchema); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithEntity(ctx context.Context, entity string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyEntity, entity)
|
||||
}
|
||||
|
||||
func GetEntity(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyEntity); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithTableName(ctx context.Context, tableName string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyTableName, tableName)
|
||||
}
|
||||
|
||||
func GetTableName(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyTableName); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithModel(ctx context.Context, model interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModel, model)
|
||||
}
|
||||
|
||||
func GetModel(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModel)
|
||||
}
|
||||
|
||||
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
|
||||
}
|
||||
|
||||
func GetModelPtr(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModelPtr)
|
||||
}
|
||||
|
||||
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||
ctx = WithSchema(ctx, schema)
|
||||
ctx = WithEntity(ctx, entity)
|
||||
ctx = WithTableName(ctx, tableName)
|
||||
ctx = WithModel(ctx, model)
|
||||
ctx = WithModelPtr(ctx, modelPtr)
|
||||
return ctx
|
||||
}
|
||||
161
pkg/resolvemcp/cursor.go
Normal file
161
pkg/resolvemcp/cursor.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package resolvemcp
|
||||
|
||||
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
type cursorDirection int
|
||||
|
||||
const (
|
||||
cursorForward cursorDirection = 1
|
||||
cursorBackward cursorDirection = -1
|
||||
)
|
||||
|
||||
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
|
||||
func getCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
|
||||
cursorID, direction := getActiveCursor(options)
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
sortItems := options.Sort
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
desc := strings.EqualFold(s.Direction, "desc")
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
orSQL := buildCursorPriorityChain(whereClauses)
|
||||
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
|
||||
if options.CursorForward != "" {
|
||||
return options.CursorForward, cursorForward
|
||||
}
|
||||
if options.CursorBackward != "" {
|
||||
return options.CursorBackward, cursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
func buildCursorPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
645
pkg/resolvemcp/handler.go
Normal file
645
pkg/resolvemcp/handler.go
Normal file
@@ -0,0 +1,645 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Handler exposes registered database models as MCP tools and resources.
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
name string
|
||||
version string
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database and model registry.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||
name: "resolvemcp",
|
||||
version: "1.0.0",
|
||||
}
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry.
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// GetDatabase returns the underlying database.
|
||||
func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
|
||||
func (h *Handler) MCPServer() *server.MCPServer {
|
||||
return h.mcpServer
|
||||
}
|
||||
|
||||
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := h.registry.RegisterModel(fullName, model); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||
func buildModelName(schema, entity string) string {
|
||||
if schema == "" {
|
||||
return entity
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schema, entity)
|
||||
}
|
||||
|
||||
// getTableName returns the fully qualified table name for a model.
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||
tableName := tableProvider.TableName()
|
||||
if idx := strings.LastIndex(tableName, "."); idx != -1 {
|
||||
return tableName[:idx], tableName[idx+1:]
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), tableName
|
||||
}
|
||||
return defaultSchema, tableName
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), entity
|
||||
}
|
||||
return defaultSchema, entity
|
||||
}
|
||||
|
||||
// executeRead reads records from the database and returns raw data + metadata.
|
||||
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
unwrapped, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = unwrapped.Model
|
||||
modelType := unwrapped.ModelType
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
|
||||
|
||||
validator := common.NewColumnValidator(model)
|
||||
options = validator.FilterRequestOptions(options)
|
||||
|
||||
// BeforeHandle hook
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "read",
|
||||
Options: options,
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||
modelPtr := reflect.New(sliceType).Interface()
|
||||
|
||||
query := h.db.NewSelect().Model(modelPtr)
|
||||
|
||||
tempInstance := reflect.New(modelType).Interface()
|
||||
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
// Column selection
|
||||
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
for _, col := range options.Columns {
|
||||
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||
}
|
||||
for _, cu := range options.ComputedColumns {
|
||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||
}
|
||||
|
||||
// Preloads
|
||||
if len(options.Preload) > 0 {
|
||||
var err error
|
||||
query, err = h.applyPreloads(model, query, options.Preload)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Filters
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
// Custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
query = query.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Sorting
|
||||
for _, sort := range options.Sort {
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Cursor pagination
|
||||
if options.CursorForward != "" || options.CursorBackward != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// expandJoins is empty for resolvemcp — no custom SQL join support yet
|
||||
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||
}
|
||||
|
||||
if cursorFilter != "" {
|
||||
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
sanitized = common.EnsureOuterParentheses(sanitized)
|
||||
if sanitized != "" {
|
||||
query = query.Where(sanitized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||
}
|
||||
|
||||
// Pagination
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
query = query.Limit(*options.Limit)
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
|
||||
// BeforeRead hook
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var data interface{}
|
||||
if id != "" {
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, fmt.Errorf("record not found")
|
||||
}
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = singleResult
|
||||
} else {
|
||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||
}
|
||||
|
||||
limit := 0
|
||||
offset := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Count is the number of records in this page, not the total.
|
||||
var pageCount int64
|
||||
if id != "" {
|
||||
pageCount = 1
|
||||
} else {
|
||||
pageCount = int64(reflect.ValueOf(data).Len())
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Count: pageCount,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// AfterRead hook
|
||||
hookCtx.Result = data
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return data, metadata, nil
|
||||
}
|
||||
|
||||
// executeCreate inserts one or more records.
|
||||
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "create",
|
||||
Data: data,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use potentially modified data
|
||||
data = hookCtx.Data
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = v
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case []interface{}:
|
||||
results := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("each item must be an object")
|
||||
}
|
||||
q := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
q = q.Value(key, value)
|
||||
}
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = results
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("data must be an object or array of objects")
|
||||
}
|
||||
}
|
||||
|
||||
// executeUpdate updates a record by ID.
|
||||
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
updates, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("data must be an object")
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
if idVal, exists := updates["id"]; exists {
|
||||
id = fmt.Sprintf("%v", idVal)
|
||||
}
|
||||
}
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
var updateResult interface{}
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Read existing record
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
existingRecord := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
return fmt.Errorf("error fetching existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "update",
|
||||
ID: id,
|
||||
Data: updates,
|
||||
Tx: tx,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
updates = modifiedData
|
||||
}
|
||||
|
||||
// Merge non-nil, non-empty values
|
||||
for key, newValue := range updates {
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
res, err := q.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating record: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
|
||||
updateResult = existingMap
|
||||
hookCtx.Result = updateResult
|
||||
return h.hooks.Execute(AfterUpdate, hookCtx)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updateResult, nil
|
||||
}
|
||||
|
||||
// executeDelete deletes a record by ID.
|
||||
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "delete",
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
var recordToDelete interface{}
|
||||
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
record := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(record).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("record not found")
|
||||
}
|
||||
return fmt.Errorf("error fetching record: %w", err)
|
||||
}
|
||||
|
||||
res, err := tx.NewDelete().Table(tableName).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete error: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("record not found or already deleted")
|
||||
}
|
||||
|
||||
recordToDelete = record
|
||||
hookCtx.Tx = tx
|
||||
hookCtx.Result = record
|
||||
return h.hooks.Execute(AfterDelete, hookCtx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
|
||||
return recordToDelete, nil
|
||||
}
|
||||
|
||||
// applyFilters applies all filters with OR grouping logic.
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||
case "neq", "!=", "<>":
|
||||
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gt", ">":
|
||||
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gte", ">=":
|
||||
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lt", "<":
|
||||
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lte", "<=":
|
||||
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return condition, args
|
||||
case "is_null":
|
||||
return fmt.Sprintf("%s IS NULL", filter.Column), nil
|
||||
case "is_not_null":
|
||||
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for _, preload := range preloads {
|
||||
if preload.Relation == "" {
|
||||
continue
|
||||
}
|
||||
query = query.PreloadRelation(preload.Relation)
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
113
pkg/resolvemcp/hooks.go
Normal file
113
pkg/resolvemcp/hooks.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
|
||||
BeforeCreate HookType = "before_create"
|
||||
AfterCreate HookType = "after_create"
|
||||
|
||||
BeforeUpdate HookType = "before_update"
|
||||
AfterUpdate HookType = "after_update"
|
||||
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler
|
||||
Schema string
|
||||
Entity string
|
||||
Model interface{}
|
||||
Options common.RequestOptions
|
||||
Operation string
|
||||
ID string
|
||||
Data interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
Query common.SelectQuery
|
||||
Abort bool
|
||||
AbortMessage string
|
||||
AbortCode int
|
||||
Tx common.Database
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
if ctx.Abort {
|
||||
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
return exists && len(hooks) > 0
|
||||
}
|
||||
83
pkg/resolvemcp/resolvemcp.go
Normal file
83
pkg/resolvemcp/resolvemcp.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
|
||||
// and resources over HTTP/SSE transport.
|
||||
//
|
||||
// It mirrors the resolvespec package patterns:
|
||||
// - Same model registration API
|
||||
// - Same filter, sort, cursor pagination, preload options
|
||||
// - Same lifecycle hook system
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// handler := resolvemcp.NewHandlerWithGORM(db)
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080")
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"github.com/uptrace/bun"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||
func NewHandlerWithGORM(db *gorm.DB) *Handler {
|
||||
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry())
|
||||
}
|
||||
|
||||
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||
func NewHandlerWithBun(db *bun.DB) *Handler {
|
||||
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry())
|
||||
}
|
||||
|
||||
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||
func NewHandlerWithDB(db common.Database) *Handler {
|
||||
return NewHandler(db, modelregistry.NewModelRegistry())
|
||||
}
|
||||
|
||||
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router.
|
||||
//
|
||||
// baseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET /mcp/sse — SSE connection endpoint (client subscribes here)
|
||||
// - POST /mcp/message — JSON-RPC message endpoint (client sends requests here)
|
||||
//
|
||||
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||
// before calling SetupMuxRoutes.
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, baseURL string) {
|
||||
sseServer := server.NewSSEServer(
|
||||
handler.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath("/mcp"),
|
||||
)
|
||||
|
||||
muxRouter.Handle("/mcp/sse", sseServer.SSEHandler()).Methods("GET", "OPTIONS")
|
||||
muxRouter.Handle("/mcp/message", sseServer.MessageHandler()).Methods("POST", "OPTIONS")
|
||||
|
||||
// Convenience: also expose the full SSE server at /mcp for clients that
|
||||
// use ServeHTTP directly (e.g. net/http default mux).
|
||||
muxRouter.PathPrefix("/mcp").Handler(http.StripPrefix("/mcp", sseServer))
|
||||
}
|
||||
|
||||
// NewSSEServer creates an *server.SSEServer that can be mounted manually,
|
||||
// useful when integrating with non-Mux routers or adding extra middleware.
|
||||
//
|
||||
// sseServer := resolvemcp.NewSSEServer(handler, "http://localhost:8080", "/mcp")
|
||||
// http.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
|
||||
func NewSSEServer(handler *Handler, baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
handler.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath(basePath),
|
||||
)
|
||||
}
|
||||
692
pkg/resolvemcp/tools.go
Normal file
692
pkg/resolvemcp/tools.go
Normal file
@@ -0,0 +1,692 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// toolName builds the MCP tool name for a given operation and model.
|
||||
func toolName(operation, schema, entity string) string {
|
||||
if schema == "" {
|
||||
return fmt.Sprintf("%s_%s", operation, entity)
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
|
||||
}
|
||||
|
||||
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||
info := buildModelInfo(schema, entity, model)
|
||||
registerReadTool(h, schema, entity, info)
|
||||
registerCreateTool(h, schema, entity, info)
|
||||
registerUpdateTool(h, schema, entity, info)
|
||||
registerDeleteTool(h, schema, entity, info)
|
||||
registerModelResource(h, schema, entity, info)
|
||||
|
||||
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Model introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
|
||||
type modelInfo struct {
|
||||
fullName string // e.g. "public.users"
|
||||
pkName string // e.g. "id"
|
||||
columns []columnInfo
|
||||
relationNames []string
|
||||
schemaDoc string // formatted multi-line schema listing
|
||||
}
|
||||
|
||||
type columnInfo struct {
|
||||
jsonName string
|
||||
sqlName string
|
||||
goType string
|
||||
sqlType string
|
||||
isPrimary bool
|
||||
isUnique bool
|
||||
isFK bool
|
||||
nullable bool
|
||||
}
|
||||
|
||||
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
|
||||
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
info := modelInfo{
|
||||
fullName: buildModelName(schema, entity),
|
||||
pkName: reflection.GetPrimaryKeyName(model),
|
||||
}
|
||||
|
||||
// Unwrap to base struct type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return info
|
||||
}
|
||||
|
||||
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
|
||||
|
||||
for _, d := range details {
|
||||
// Derive the JSON name from the struct field
|
||||
jsonName := fieldJSONName(modelType, d.Name)
|
||||
if jsonName == "" || jsonName == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (slice or user-defined struct that isn't time.Time).
|
||||
fieldType, found := modelType.FieldByName(d.Name)
|
||||
if found {
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||
if ft.Kind() == reflect.Slice || isUserStruct {
|
||||
info.relationNames = append(info.relationNames, jsonName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
sqlName := d.SQLName
|
||||
if sqlName == "" {
|
||||
sqlName = jsonName
|
||||
}
|
||||
|
||||
// Derive Go type name, unwrapping pointer if needed.
|
||||
goType := d.DataType
|
||||
if goType == "" && found {
|
||||
ft := fieldType.Type
|
||||
for ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
goType = ft.Name()
|
||||
}
|
||||
|
||||
// isPrimary: use both the GORM-tag detection and a name comparison against
|
||||
// the known primary key (handles camelCase "primaryKey" tags correctly).
|
||||
isPrimary := d.SQLKey == "primary_key" ||
|
||||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
|
||||
|
||||
ci := columnInfo{
|
||||
jsonName: jsonName,
|
||||
sqlName: sqlName,
|
||||
goType: goType,
|
||||
sqlType: d.SQLDataType,
|
||||
isPrimary: isPrimary,
|
||||
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
|
||||
isFK: d.SQLKey == "foreign_key",
|
||||
nullable: d.Nullable,
|
||||
}
|
||||
info.columns = append(info.columns, ci)
|
||||
}
|
||||
|
||||
info.schemaDoc = buildSchemaDoc(info)
|
||||
return info
|
||||
}
|
||||
|
||||
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
|
||||
func fieldJSONName(modelType reflect.Type, fieldName string) string {
|
||||
field, ok := modelType.FieldByName(fieldName)
|
||||
if !ok {
|
||||
return fieldName
|
||||
}
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "" {
|
||||
return fieldName
|
||||
}
|
||||
parts := strings.SplitN(tag, ",", 2)
|
||||
if parts[0] == "" {
|
||||
return fieldName
|
||||
}
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
|
||||
func buildSchemaDoc(info modelInfo) string {
|
||||
if len(info.columns) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Columns:\n")
|
||||
for _, c := range info.columns {
|
||||
line := fmt.Sprintf(" • %s", c.jsonName)
|
||||
|
||||
typeDesc := c.goType
|
||||
if c.sqlType != "" {
|
||||
typeDesc = c.sqlType
|
||||
}
|
||||
if typeDesc != "" {
|
||||
line += fmt.Sprintf(" (%s)", typeDesc)
|
||||
}
|
||||
|
||||
var flags []string
|
||||
if c.isPrimary {
|
||||
flags = append(flags, "primary key")
|
||||
}
|
||||
if c.isUnique {
|
||||
flags = append(flags, "unique")
|
||||
}
|
||||
if c.isFK {
|
||||
flags = append(flags, "foreign key")
|
||||
}
|
||||
if !c.nullable && !c.isPrimary {
|
||||
flags = append(flags, "not null")
|
||||
} else if c.nullable {
|
||||
flags = append(flags, "nullable")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
line += " — " + strings.Join(flags, ", ")
|
||||
}
|
||||
|
||||
sb.WriteString(line + "\n")
|
||||
}
|
||||
|
||||
if len(info.relationNames) > 0 {
|
||||
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
|
||||
func columnNameList(cols []columnInfo) string {
|
||||
names := make([]string, len(cols))
|
||||
for i, c := range cols {
|
||||
names[i] = c.jsonName
|
||||
}
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// writableColumnNames returns JSON names for all non-primary-key columns.
|
||||
func writableColumnNames(cols []columnInfo) []string {
|
||||
var names []string
|
||||
for _, c := range cols {
|
||||
if !c.isPrimary {
|
||||
names = append(names, c.jsonName)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Read tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("read", schema, entity)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
|
||||
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
|
||||
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
|
||||
)
|
||||
if len(info.relationNames) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
|
||||
if len(info.columns) > 0 {
|
||||
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
|
||||
if len(info.columns) > 0 {
|
||||
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
|
||||
),
|
||||
mcp.WithNumber("limit",
|
||||
mcp.Description("Maximum number of records to return per page. Recommended: 10–100."),
|
||||
),
|
||||
mcp.WithNumber("offset",
|
||||
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
|
||||
),
|
||||
mcp.WithString("cursor_forward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithString("cursor_backward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithArray("columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("omit_columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("filters",
|
||||
mcp.Description(filterDesc),
|
||||
),
|
||||
mcp.WithArray("sort",
|
||||
mcp.Description(sortDesc),
|
||||
),
|
||||
mcp.WithArray("preloads",
|
||||
mcp.Description(buildPreloadDesc(info)),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
options := parseRequestOptions(args)
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func buildPreloadDesc(info modelInfo) string {
|
||||
if len(info.relationNames) == 0 {
|
||||
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
|
||||
strings.Join(info.relationNames, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("create", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
dataDesc := "Record fields to create."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
dataDesc += " Pass a single object or an array of objects."
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeCreate(ctx, schema, entity, data)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Update tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("update", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
|
||||
}
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
|
||||
|
||||
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(idDesc),
|
||||
),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("data must be an object"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Delete tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("delete", schema, entity)
|
||||
|
||||
descParts := []string{
|
||||
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
|
||||
|
||||
description := strings.Join(descParts, " ")
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
result, err := h.executeDelete(ctx, schema, entity, id)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Resource registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
|
||||
resourceURI := info.fullName
|
||||
|
||||
var resourceDesc strings.Builder
|
||||
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
|
||||
if info.pkName != "" {
|
||||
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
resourceDesc.WriteString("\n\n")
|
||||
resourceDesc.WriteString(info.schemaDoc)
|
||||
}
|
||||
|
||||
resource := mcp.NewResource(
|
||||
resourceURI,
|
||||
entity,
|
||||
mcp.WithResourceDescription(resourceDesc.String()),
|
||||
mcp.WithMIMEType("application/json"),
|
||||
)
|
||||
|
||||
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
limit := 100
|
||||
options := common.RequestOptions{Limit: &limit}
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
}
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling resource: %w", err)
|
||||
}
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: req.Params.URI,
|
||||
MIMEType: "application/json",
|
||||
Text: string(jsonBytes),
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Argument parsing helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
|
||||
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||
options := common.RequestOptions{}
|
||||
|
||||
if v, ok := args["limit"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
limit := int(n)
|
||||
options.Limit = &limit
|
||||
case int:
|
||||
options.Limit = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["offset"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
offset := int(n)
|
||||
options.Offset = &offset
|
||||
case int:
|
||||
options.Offset = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["cursor_forward"].(string); ok {
|
||||
options.CursorForward = v
|
||||
}
|
||||
if v, ok := args["cursor_backward"].(string); ok {
|
||||
options.CursorBackward = v
|
||||
}
|
||||
|
||||
options.Columns = parseStringArray(args["columns"])
|
||||
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||
options.Filters = parseFilters(args["filters"])
|
||||
options.Sort = parseSortOptions(args["sort"])
|
||||
options.Preload = parsePreloadOptions(args["preloads"])
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
func parseStringArray(raw interface{}) []string {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseFilters(raw interface{}) []common.FilterOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.FilterOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var f common.FilterOption
|
||||
if err := json.Unmarshal(b, &f); err != nil {
|
||||
continue
|
||||
}
|
||||
if f.Column == "" || f.Operator == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(f.LogicOperator, "or") {
|
||||
f.LogicOperator = "OR"
|
||||
} else {
|
||||
f.LogicOperator = "AND"
|
||||
}
|
||||
result = append(result, f)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseSortOptions(raw interface{}) []common.SortOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.SortOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var s common.SortOption
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
continue
|
||||
}
|
||||
if s.Column == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.PreloadOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var p common.PreloadOption
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
continue
|
||||
}
|
||||
if p.Relation == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// marshalResult marshals a value to JSON and returns it as an MCP text result.
|
||||
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
|
||||
}
|
||||
return mcp.NewToolResultText(string(b)), nil
|
||||
}
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - options: the request options containing sort and cursor information
|
||||
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func GetCursorFilter(
|
||||
@@ -31,6 +32,7 @@ func GetCursorFilter(
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
// Separate schema prefix from bare table name
|
||||
fullTableName := tableName
|
||||
@@ -58,6 +60,7 @@ func GetCursorFilter(
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
@@ -69,7 +72,7 @@ func GetCursorFilter(
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "created_at", "user.name", etc.
|
||||
// Parse: "created_at", "user.name", "fn.sortorder", etc.
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
@@ -82,7 +85,7 @@ func GetCursorFilter(
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, err := resolveColumn(
|
||||
cursorCol, targetCol, isJoin, err := resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -90,6 +93,22 @@ func GetCursorFilter(
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
@@ -113,10 +132,12 @@ func GetCursorFilter(
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
@@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: resolve column (main table only for now)
|
||||
// Helper: resolve column (main table or join)
|
||||
func resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, err error) {
|
||||
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Joined column (not supported in resolvespec yet)
|
||||
// Joined column
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// Helper: rewrite JOIN clause for cursor subquery
|
||||
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
|
||||
CursorForward: "8975",
|
||||
}
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Joined column (not supported)",
|
||||
name: "Joined column (isJoin=true, no error)",
|
||||
field: "name",
|
||||
prefix: "user",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "title"},
|
||||
wantErr: true,
|
||||
wantErr: false,
|
||||
// cursorCol and targetCol are empty when isJoin=true; handled by caller
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// For join columns, cursor/target are empty and isJoin=true
|
||||
if isJoin {
|
||||
if cursor != "" || target != "" {
|
||||
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cursor != tt.wantCursor {
|
||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||
}
|
||||
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// Get cursor filter SQL
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
logger.Error("Error building cursor filter: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||
|
||||
@@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin && expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "fn.sortorder", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "8975"
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
// modelColumns does not contain "sortorder" - it's a lateral join computed column
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
// Should contain the rewritten lateral join inside the EXISTS subquery
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should compare fn.sortorder values
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should NOT contain empty comparison like "< "
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > posts.priority",
|
||||
|
||||
@@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Extract model columns for validation using the generic database function
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Build expand joins map (if needed in future)
|
||||
var expandJoins map[string]string
|
||||
if len(options.Expand) > 0 {
|
||||
expandJoins = make(map[string]string)
|
||||
// TODO: Build actual JOIN SQL for each expand relation
|
||||
// For now, pass empty map as joins are handled via Preload
|
||||
// Build expand joins map: custom SQL joins are available in cursor subquery
|
||||
expandJoins := make(map[string]string)
|
||||
for _, joinClause := range options.CustomSQLJoin {
|
||||
alias := extractJoinAlias(joinClause)
|
||||
if alias != "" {
|
||||
expandJoins[alias] = joinClause
|
||||
}
|
||||
}
|
||||
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
|
||||
|
||||
// Default sort to primary key when none provided
|
||||
if len(options.Sort) == 0 {
|
||||
|
||||
@@ -552,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
|
||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||
// - "JOIN roles r ON ..." -> "r"
|
||||
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
|
||||
func extractJoinAlias(joinClause string) string {
|
||||
// Pattern: JOIN table_name [AS] alias ON ...
|
||||
// We need to extract the alias (word before ON)
|
||||
|
||||
upperJoin := strings.ToUpper(joinClause)
|
||||
|
||||
// Find the "JOIN" keyword position
|
||||
@@ -564,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the "ON" keyword position
|
||||
// Lateral joins: alias is the word after the closing ) and before ON
|
||||
if strings.Contains(upperJoin, "LATERAL") {
|
||||
lastClose := strings.LastIndex(joinClause, ")")
|
||||
if lastClose != -1 {
|
||||
words := strings.Fields(joinClause[lastClose+1:])
|
||||
// words should be like ["fn", "on", "true"] or ["on", "true"]
|
||||
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
|
||||
return words[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Regular joins: find the "ON" keyword position (first occurrence)
|
||||
onIdx := strings.Index(upperJoin, " ON ")
|
||||
if onIdx == -1 {
|
||||
return ""
|
||||
|
||||
@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
|
||||
joinClause: "LEFT JOIN departments",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with alias",
|
||||
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with multiline subquery containing inner ON",
|
||||
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user