mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-05 07:32:26 +00:00
feat(resolvemcp): add hook system for model operations
* Implement hooks for CRUD operations: before/after handle, read, create, update, delete. * Introduce HookContext and HookRegistry for managing hooks. * Allow registration and execution of multiple hooks per operation. feat(resolvemcp): implement MCP tools for CRUD operations * Register tools for reading, creating, updating, and deleting records. * Define tool arguments and handle requests with appropriate responses. * Support for resource registration with metadata. fix(restheadspec): enhance cursor handling for joins * Improve cursor filter generation to support lateral joins. * Update join alias extraction to handle lateral joins correctly. * Ensure cursor filters do not contain empty comparisons. test(restheadspec): add tests for cursor filters and join alias extraction * Create tests for lateral join scenarios in cursor filter generation. * Validate join alias extraction for various join types, including lateral joins.
This commit is contained in:
5
go.mod
5
go.mod
@@ -40,6 +40,7 @@ require (
|
|||||||
go.opentelemetry.io/otel/trace v1.38.0
|
go.opentelemetry.io/otel/trace v1.38.0
|
||||||
go.uber.org/zap v1.27.1
|
go.uber.org/zap v1.27.1
|
||||||
golang.org/x/crypto v0.46.0
|
golang.org/x/crypto v0.46.0
|
||||||
|
golang.org/x/oauth2 v0.34.0
|
||||||
golang.org/x/time v0.14.0
|
golang.org/x/time v0.14.0
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/driver/sqlite v1.6.0
|
gorm.io/driver/sqlite v1.6.0
|
||||||
@@ -78,6 +79,7 @@ require (
|
|||||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||||
github.com/golang/snappy v1.0.0 // indirect
|
github.com/golang/snappy v1.0.0 // indirect
|
||||||
|
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
@@ -86,6 +88,7 @@ require (
|
|||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
github.com/magiconair/properties v1.8.10 // indirect
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
|
github.com/mark3labs/mcp-go v0.46.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
github.com/moby/go-archive v0.1.0 // indirect
|
github.com/moby/go-archive v0.1.0 // indirect
|
||||||
@@ -131,6 +134,7 @@ require (
|
|||||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||||
github.com/xdg-go/scram v1.2.0 // indirect
|
github.com/xdg-go/scram v1.2.0 // indirect
|
||||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
@@ -143,7 +147,6 @@ require (
|
|||||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.31.0 // indirect
|
||||||
golang.org/x/net v0.48.0 // indirect
|
golang.org/x/net v0.48.0 // indirect
|
||||||
golang.org/x/oauth2 v0.34.0 // indirect
|
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
golang.org/x/sys v0.39.0 // indirect
|
golang.org/x/sys v0.39.0 // indirect
|
||||||
golang.org/x/text v0.32.0 // indirect
|
golang.org/x/text v0.32.0 // indirect
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||||
|
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
|||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
|
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||||
|
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||||
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
|||||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
|||||||
71
pkg/resolvemcp/context.go
Normal file
71
pkg/resolvemcp/context.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
contextKeySchema contextKey = "schema"
|
||||||
|
contextKeyEntity contextKey = "entity"
|
||||||
|
contextKeyTableName contextKey = "tableName"
|
||||||
|
contextKeyModel contextKey = "model"
|
||||||
|
contextKeyModelPtr contextKey = "modelPtr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithSchema(ctx context.Context, schema string) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeySchema, schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSchema(ctx context.Context) string {
|
||||||
|
if v := ctx.Value(contextKeySchema); v != nil {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithEntity(ctx context.Context, entity string) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyEntity, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEntity(ctx context.Context) string {
|
||||||
|
if v := ctx.Value(contextKeyEntity); v != nil {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTableName(ctx context.Context, tableName string) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyTableName, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTableName(ctx context.Context) string {
|
||||||
|
if v := ctx.Value(contextKeyTableName); v != nil {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithModel(ctx context.Context, model interface{}) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyModel, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModel(ctx context.Context) interface{} {
|
||||||
|
return ctx.Value(contextKeyModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelPtr(ctx context.Context) interface{} {
|
||||||
|
return ctx.Value(contextKeyModelPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||||
|
ctx = WithSchema(ctx, schema)
|
||||||
|
ctx = WithEntity(ctx, entity)
|
||||||
|
ctx = WithTableName(ctx, tableName)
|
||||||
|
ctx = WithModel(ctx, model)
|
||||||
|
ctx = WithModelPtr(ctx, modelPtr)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
133
pkg/resolvemcp/cursor.go
Normal file
133
pkg/resolvemcp/cursor.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cursorDirection int
|
||||||
|
|
||||||
|
const (
|
||||||
|
cursorForward cursorDirection = 1
|
||||||
|
cursorBackward cursorDirection = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||||
|
func getCursorFilter(
|
||||||
|
tableName string,
|
||||||
|
pkName string,
|
||||||
|
modelColumns []string,
|
||||||
|
options common.RequestOptions,
|
||||||
|
) (string, error) {
|
||||||
|
fullTableName := tableName
|
||||||
|
if strings.Contains(tableName, ".") {
|
||||||
|
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
cursorID, direction := getActiveCursor(options)
|
||||||
|
if cursorID == "" {
|
||||||
|
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
sortItems := options.Sort
|
||||||
|
if len(sortItems) == 0 {
|
||||||
|
return "", fmt.Errorf("no sort columns defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
var whereClauses []string
|
||||||
|
reverse := direction < 0
|
||||||
|
|
||||||
|
for _, s := range sortItems {
|
||||||
|
col := strings.TrimSpace(s.Column)
|
||||||
|
if col == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(col, ".")
|
||||||
|
field := strings.TrimSpace(parts[len(parts)-1])
|
||||||
|
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||||
|
|
||||||
|
desc := strings.EqualFold(s.Direction, "desc")
|
||||||
|
if reverse {
|
||||||
|
desc = !desc
|
||||||
|
}
|
||||||
|
|
||||||
|
cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
op := "<"
|
||||||
|
if desc {
|
||||||
|
op = ">"
|
||||||
|
}
|
||||||
|
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(whereClauses) == 0 {
|
||||||
|
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||||
|
}
|
||||||
|
|
||||||
|
orSQL := buildCursorPriorityChain(whereClauses)
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM %s cursor_select
|
||||||
|
WHERE cursor_select.%s = %s
|
||||||
|
AND (%s)
|
||||||
|
)`,
|
||||||
|
fullTableName,
|
||||||
|
pkName,
|
||||||
|
cursorID,
|
||||||
|
orSQL,
|
||||||
|
)
|
||||||
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
|
||||||
|
if options.CursorForward != "" {
|
||||||
|
return options.CursorForward, cursorForward
|
||||||
|
}
|
||||||
|
if options.CursorBackward != "" {
|
||||||
|
return options.CursorBackward, cursorBackward
|
||||||
|
}
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) {
|
||||||
|
if strings.Contains(field, "->") {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelColumns != nil {
|
||||||
|
for _, col := range modelColumns {
|
||||||
|
if strings.EqualFold(col, field) {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix != "" && prefix != tableName {
|
||||||
|
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCursorPriorityChain(clauses []string) string {
|
||||||
|
var or []string
|
||||||
|
for i := 0; i < len(clauses); i++ {
|
||||||
|
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||||
|
or = append(or, "("+and+")")
|
||||||
|
}
|
||||||
|
return strings.Join(or, "\n OR ")
|
||||||
|
}
|
||||||
644
pkg/resolvemcp/handler.go
Normal file
644
pkg/resolvemcp/handler.go
Normal file
@@ -0,0 +1,644 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/server"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler exposes registered database models as MCP tools and resources.
|
||||||
|
type Handler struct {
|
||||||
|
db common.Database
|
||||||
|
registry common.ModelRegistry
|
||||||
|
hooks *HookRegistry
|
||||||
|
mcpServer *server.MCPServer
|
||||||
|
name string
|
||||||
|
version string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a Handler with the given database and model registry.
|
||||||
|
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
||||||
|
return &Handler{
|
||||||
|
db: db,
|
||||||
|
registry: registry,
|
||||||
|
hooks: NewHookRegistry(),
|
||||||
|
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||||
|
name: "resolvemcp",
|
||||||
|
version: "1.0.0",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks returns the hook registry.
|
||||||
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
|
return h.hooks
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDatabase returns the underlying database.
|
||||||
|
func (h *Handler) GetDatabase() common.Database {
|
||||||
|
return h.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
|
||||||
|
func (h *Handler) MCPServer() *server.MCPServer {
|
||||||
|
return h.mcpServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||||
|
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||||
|
fullName := buildModelName(schema, entity)
|
||||||
|
if err := h.registry.RegisterModel(fullName, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registerModelTools(h, schema, entity, model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||||
|
func buildModelName(schema, entity string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return entity
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", schema, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableName returns the fully qualified table name for a model.
|
||||||
|
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||||
|
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||||
|
if schemaName != "" {
|
||||||
|
if h.db.DriverName() == "sqlite" {
|
||||||
|
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||||
|
}
|
||||||
|
return tableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||||
|
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
tableName := tableProvider.TableName()
|
||||||
|
if idx := strings.LastIndex(tableName, "."); idx != -1 {
|
||||||
|
return tableName[:idx], tableName[idx+1:]
|
||||||
|
}
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
return schemaProvider.SchemaName(), tableName
|
||||||
|
}
|
||||||
|
return defaultSchema, tableName
|
||||||
|
}
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
return schemaProvider.SchemaName(), entity
|
||||||
|
}
|
||||||
|
return defaultSchema, entity
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRead reads records from the database and returns raw data + metadata.
|
||||||
|
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
unwrapped, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = unwrapped.Model
|
||||||
|
modelType := unwrapped.ModelType
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
|
||||||
|
|
||||||
|
validator := common.NewColumnValidator(model)
|
||||||
|
options = validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
// BeforeHandle hook
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "read",
|
||||||
|
Options: options,
|
||||||
|
ID: id,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||||
|
modelPtr := reflect.New(sliceType).Interface()
|
||||||
|
|
||||||
|
query := h.db.NewSelect().Model(modelPtr)
|
||||||
|
|
||||||
|
tempInstance := reflect.New(modelType).Interface()
|
||||||
|
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
|
query = query.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column selection
|
||||||
|
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
for _, col := range options.Columns {
|
||||||
|
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||||
|
}
|
||||||
|
for _, cu := range options.ComputedColumns {
|
||||||
|
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preloads
|
||||||
|
if len(options.Preload) > 0 {
|
||||||
|
var err error
|
||||||
|
query, err = h.applyPreloads(model, query, options.Preload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filters
|
||||||
|
query = h.applyFilters(query, options.Filters)
|
||||||
|
|
||||||
|
// Custom operators
|
||||||
|
for _, customOp := range options.CustomOperators {
|
||||||
|
query = query.Where(customOp.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorting
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
direction := "ASC"
|
||||||
|
if strings.EqualFold(sort.Direction, "desc") {
|
||||||
|
direction = "DESC"
|
||||||
|
}
|
||||||
|
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cursor pagination
|
||||||
|
if options.CursorForward != "" || options.CursorBackward != "" {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cursorFilter != "" {
|
||||||
|
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||||
|
sanitized = common.EnsureOuterParentheses(sanitized)
|
||||||
|
if sanitized != "" {
|
||||||
|
query = query.Where(sanitized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count
|
||||||
|
total, err := query.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pagination
|
||||||
|
if options.Limit != nil && *options.Limit > 0 {
|
||||||
|
query = query.Limit(*options.Limit)
|
||||||
|
}
|
||||||
|
if options.Offset != nil && *options.Offset > 0 {
|
||||||
|
query = query.Offset(*options.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeRead hook
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var data interface{}
|
||||||
|
if id != "" {
|
||||||
|
singleResult := reflect.New(modelType).Interface()
|
||||||
|
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, fmt.Errorf("record not found")
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||||
|
}
|
||||||
|
data = singleResult
|
||||||
|
} else {
|
||||||
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||||
|
}
|
||||||
|
data = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 0
|
||||||
|
offset := 0
|
||||||
|
if options.Limit != nil {
|
||||||
|
limit = *options.Limit
|
||||||
|
}
|
||||||
|
if options.Offset != nil {
|
||||||
|
offset = *options.Offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count is the number of records in this page, not the total.
|
||||||
|
var pageCount int64
|
||||||
|
if id != "" {
|
||||||
|
pageCount = 1
|
||||||
|
} else {
|
||||||
|
pageCount = int64(reflect.ValueOf(data).Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := &common.Metadata{
|
||||||
|
Total: int64(total),
|
||||||
|
Filtered: int64(total),
|
||||||
|
Count: pageCount,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
|
||||||
|
// AfterRead hook
|
||||||
|
hookCtx.Result = data
|
||||||
|
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, metadata, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeCreate inserts one or more records.
|
||||||
|
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = result.Model
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||||
|
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "create",
|
||||||
|
Data: data,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data
|
||||||
|
data = hookCtx.Data
|
||||||
|
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
query := h.db.NewInsert().Table(tableName)
|
||||||
|
for key, value := range v {
|
||||||
|
query = query.Value(key, value)
|
||||||
|
}
|
||||||
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
|
return nil, fmt.Errorf("create error: %w", err)
|
||||||
|
}
|
||||||
|
hookCtx.Result = v
|
||||||
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
|
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case []interface{}:
|
||||||
|
results := make([]interface{}, 0, len(v))
|
||||||
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
for _, item := range v {
|
||||||
|
itemMap, ok := item.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("each item must be an object")
|
||||||
|
}
|
||||||
|
q := tx.NewInsert().Table(tableName)
|
||||||
|
for key, value := range itemMap {
|
||||||
|
q = q.Value(key, value)
|
||||||
|
}
|
||||||
|
if _, err := q.Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
results = append(results, item)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch create error: %w", err)
|
||||||
|
}
|
||||||
|
hookCtx.Result = results
|
||||||
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
|
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("data must be an object or array of objects")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeUpdate updates a record by ID.
|
||||||
|
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = result.Model
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||||
|
|
||||||
|
updates, ok := data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("data must be an object")
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == "" {
|
||||||
|
if idVal, exists := updates["id"]; exists {
|
||||||
|
id = fmt.Sprintf("%v", idVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
return nil, fmt.Errorf("update requires an ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
var updateResult interface{}
|
||||||
|
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// Read existing record
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
existingRecord := reflect.New(modelType).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("no records found to update")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error fetching existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to map
|
||||||
|
existingMap := make(map[string]interface{})
|
||||||
|
jsonData, err := json.Marshal(existingRecord)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||||
|
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "update",
|
||||||
|
ID: id,
|
||||||
|
Data: updates,
|
||||||
|
Tx: tx,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||||
|
updates = modifiedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge non-nil, non-empty values
|
||||||
|
for key, newValue := range updates {
|
||||||
|
if newValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingMap[key] = newValue
|
||||||
|
}
|
||||||
|
|
||||||
|
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
res, err := q.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error updating record: %w", err)
|
||||||
|
}
|
||||||
|
if res.RowsAffected() == 0 {
|
||||||
|
return fmt.Errorf("no records found to update")
|
||||||
|
}
|
||||||
|
|
||||||
|
updateResult = existingMap
|
||||||
|
hookCtx.Result = updateResult
|
||||||
|
return h.hooks.Execute(AfterUpdate, hookCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return updateResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeDelete deletes a record by ID.
|
||||||
|
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
|
||||||
|
if id == "" {
|
||||||
|
return nil, fmt.Errorf("delete requires an ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = result.Model
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||||
|
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "delete",
|
||||||
|
ID: id,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var recordToDelete interface{}
|
||||||
|
|
||||||
|
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
record := reflect.New(modelType).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(record).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("record not found")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error fetching record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := tx.NewDelete().Table(tableName).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete error: %w", err)
|
||||||
|
}
|
||||||
|
if res.RowsAffected() == 0 {
|
||||||
|
return fmt.Errorf("record not found or already deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
recordToDelete = record
|
||||||
|
hookCtx.Tx = tx
|
||||||
|
hookCtx.Result = record
|
||||||
|
return h.hooks.Execute(AfterDelete, hookCtx)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
|
||||||
|
return recordToDelete, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyFilters applies all filters with OR grouping logic.
|
||||||
|
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||||
|
if len(filters) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(filters) {
|
||||||
|
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||||
|
|
||||||
|
if startORGroup {
|
||||||
|
orGroup := []common.FilterOption{filters[i]}
|
||||||
|
j := i + 1
|
||||||
|
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||||
|
orGroup = append(orGroup, filters[j])
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
query = h.applyFilterGroup(query, orGroup)
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
condition, args := h.buildFilterCondition(filters[i])
|
||||||
|
if condition != "" {
|
||||||
|
query = query.Where(condition, args...)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||||
|
var conditions []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
for _, filter := range filters {
|
||||||
|
condition, filterArgs := h.buildFilterCondition(filter)
|
||||||
|
if condition != "" {
|
||||||
|
conditions = append(conditions, condition)
|
||||||
|
args = append(args, filterArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
return query.Where(conditions[0], args...)
|
||||||
|
}
|
||||||
|
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
||||||
|
switch filter.Operator {
|
||||||
|
case "eq", "=":
|
||||||
|
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "neq", "!=", "<>":
|
||||||
|
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "gt", ">":
|
||||||
|
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "gte", ">=":
|
||||||
|
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "lt", "<":
|
||||||
|
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "lte", "<=":
|
||||||
|
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "like":
|
||||||
|
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "ilike":
|
||||||
|
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "in":
|
||||||
|
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||||
|
return condition, args
|
||||||
|
case "is_null":
|
||||||
|
return fmt.Sprintf("%s IS NULL", filter.Column), nil
|
||||||
|
case "is_not_null":
|
||||||
|
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||||
|
for _, preload := range preloads {
|
||||||
|
if preload.Relation == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
query = query.PreloadRelation(preload.Relation)
|
||||||
|
}
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
113
pkg/resolvemcp/hooks.go
Normal file
113
pkg/resolvemcp/hooks.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookType defines the type of hook to execute
|
||||||
|
type HookType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
|
BeforeRead HookType = "before_read"
|
||||||
|
AfterRead HookType = "after_read"
|
||||||
|
|
||||||
|
BeforeCreate HookType = "before_create"
|
||||||
|
AfterCreate HookType = "after_create"
|
||||||
|
|
||||||
|
BeforeUpdate HookType = "before_update"
|
||||||
|
AfterUpdate HookType = "after_update"
|
||||||
|
|
||||||
|
BeforeDelete HookType = "before_delete"
|
||||||
|
AfterDelete HookType = "after_delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookContext contains all the data available to a hook
|
||||||
|
type HookContext struct {
|
||||||
|
Context context.Context
|
||||||
|
Handler *Handler
|
||||||
|
Schema string
|
||||||
|
Entity string
|
||||||
|
Model interface{}
|
||||||
|
Options common.RequestOptions
|
||||||
|
Operation string
|
||||||
|
ID string
|
||||||
|
Data interface{}
|
||||||
|
Result interface{}
|
||||||
|
Error error
|
||||||
|
Query common.SelectQuery
|
||||||
|
Abort bool
|
||||||
|
AbortMessage string
|
||||||
|
AbortCode int
|
||||||
|
Tx common.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookFunc is the signature for hook functions
|
||||||
|
type HookFunc func(*HookContext) error
|
||||||
|
|
||||||
|
// HookRegistry manages all registered hooks
|
||||||
|
type HookRegistry struct {
|
||||||
|
hooks map[HookType][]HookFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHookRegistry() *HookRegistry {
|
||||||
|
return &HookRegistry{
|
||||||
|
hooks: make(map[HookType][]HookFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||||
|
if r.hooks == nil {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||||
|
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
r.Register(hookType, hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
if !exists || len(hooks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
|
||||||
|
|
||||||
|
for i, hook := range hooks {
|
||||||
|
if err := hook(ctx); err != nil {
|
||||||
|
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Abort {
|
||||||
|
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) Clear(hookType HookType) {
|
||||||
|
delete(r.hooks, hookType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) ClearAll() {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
return exists && len(hooks) > 0
|
||||||
|
}
|
||||||
83
pkg/resolvemcp/resolvemcp.go
Normal file
83
pkg/resolvemcp/resolvemcp.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
|
||||||
|
// and resources over HTTP/SSE transport.
|
||||||
|
//
|
||||||
|
// It mirrors the resolvespec package patterns:
|
||||||
|
// - Same model registration API
|
||||||
|
// - Same filter, sort, cursor pagination, preload options
|
||||||
|
// - Same lifecycle hook system
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// handler := resolvemcp.NewHandlerWithGORM(db)
|
||||||
|
// handler.RegisterModel("public", "users", &User{})
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080")
|
||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/mark3labs/mcp-go/server"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||||
|
func NewHandlerWithGORM(db *gorm.DB) *Handler {
|
||||||
|
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||||
|
func NewHandlerWithBun(db *bun.DB) *Handler {
|
||||||
|
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||||
|
func NewHandlerWithDB(db common.Database) *Handler {
|
||||||
|
return NewHandler(db, modelregistry.NewModelRegistry())
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router.
|
||||||
|
//
|
||||||
|
// baseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||||
|
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||||
|
//
|
||||||
|
// Two routes are registered:
|
||||||
|
// - GET /mcp/sse — SSE connection endpoint (client subscribes here)
|
||||||
|
// - POST /mcp/message — JSON-RPC message endpoint (client sends requests here)
|
||||||
|
//
|
||||||
|
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||||
|
// before calling SetupMuxRoutes.
|
||||||
|
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, baseURL string) {
|
||||||
|
sseServer := server.NewSSEServer(
|
||||||
|
handler.mcpServer,
|
||||||
|
server.WithBaseURL(baseURL),
|
||||||
|
server.WithBasePath("/mcp"),
|
||||||
|
)
|
||||||
|
|
||||||
|
muxRouter.Handle("/mcp/sse", sseServer.SSEHandler()).Methods("GET", "OPTIONS")
|
||||||
|
muxRouter.Handle("/mcp/message", sseServer.MessageHandler()).Methods("POST", "OPTIONS")
|
||||||
|
|
||||||
|
// Convenience: also expose the full SSE server at /mcp for clients that
|
||||||
|
// use ServeHTTP directly (e.g. net/http default mux).
|
||||||
|
muxRouter.PathPrefix("/mcp").Handler(http.StripPrefix("/mcp", sseServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSEServer creates an *server.SSEServer that can be mounted manually,
|
||||||
|
// useful when integrating with non-Mux routers or adding extra middleware.
|
||||||
|
//
|
||||||
|
// sseServer := resolvemcp.NewSSEServer(handler, "http://localhost:8080", "/mcp")
|
||||||
|
// http.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
|
||||||
|
func NewSSEServer(handler *Handler, baseURL, basePath string) *server.SSEServer {
|
||||||
|
return server.NewSSEServer(
|
||||||
|
handler.mcpServer,
|
||||||
|
server.WithBaseURL(baseURL),
|
||||||
|
server.WithBasePath(basePath),
|
||||||
|
)
|
||||||
|
}
|
||||||
415
pkg/resolvemcp/tools.go
Normal file
415
pkg/resolvemcp/tools.go
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// toolName builds the MCP tool name for a given operation and model.
|
||||||
|
func toolName(operation, schema, entity string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return fmt.Sprintf("%s_%s", operation, entity)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||||
|
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||||
|
registerReadTool(h, schema, entity)
|
||||||
|
registerCreateTool(h, schema, entity)
|
||||||
|
registerUpdateTool(h, schema, entity)
|
||||||
|
registerDeleteTool(h, schema, entity)
|
||||||
|
registerModelResource(h, schema, entity)
|
||||||
|
|
||||||
|
logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Read tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerReadTool(h *Handler, schema, entity string) {
|
||||||
|
name := toolName("read", schema, entity)
|
||||||
|
description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity))
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithString("id",
|
||||||
|
mcp.Description("Primary key of a single record to fetch (optional)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("limit",
|
||||||
|
mcp.Description("Maximum number of records to return"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("offset",
|
||||||
|
mcp.Description("Number of records to skip"),
|
||||||
|
),
|
||||||
|
mcp.WithString("cursor_forward",
|
||||||
|
mcp.Description("Cursor value for the next page (primary key of last record on current page)"),
|
||||||
|
),
|
||||||
|
mcp.WithString("cursor_backward",
|
||||||
|
mcp.Description("Cursor value for the previous page"),
|
||||||
|
),
|
||||||
|
mcp.WithArray("columns",
|
||||||
|
mcp.Description("List of column names to include in the result"),
|
||||||
|
),
|
||||||
|
mcp.WithArray("omit_columns",
|
||||||
|
mcp.Description("List of column names to exclude from the result"),
|
||||||
|
),
|
||||||
|
mcp.WithArray("filters",
|
||||||
|
mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`),
|
||||||
|
),
|
||||||
|
mcp.WithArray("sort",
|
||||||
|
mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`),
|
||||||
|
),
|
||||||
|
mcp.WithArray("preloads",
|
||||||
|
mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
id, _ := args["id"].(string)
|
||||||
|
options := parseRequestOptions(args)
|
||||||
|
|
||||||
|
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": data,
|
||||||
|
"metadata": metadata,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Create tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerCreateTool(h *Handler, schema, entity string) {
|
||||||
|
name := toolName("create", schema, entity)
|
||||||
|
description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity))
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithObject("data",
|
||||||
|
mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
data, ok := args["data"]
|
||||||
|
if !ok {
|
||||||
|
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.executeCreate(ctx, schema, entity, data)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Update tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerUpdateTool(h *Handler, schema, entity string) {
|
||||||
|
name := toolName("update", schema, entity)
|
||||||
|
description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity))
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithString("id",
|
||||||
|
mcp.Description("Primary key of the record to update"),
|
||||||
|
),
|
||||||
|
mcp.WithObject("data",
|
||||||
|
mcp.Description("Fields to update (non-null fields will be merged into the existing record)"),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
id, _ := args["id"].(string)
|
||||||
|
|
||||||
|
data, ok := args["data"]
|
||||||
|
if !ok {
|
||||||
|
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||||
|
}
|
||||||
|
dataMap, ok := data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return mcp.NewToolResultError("data must be an object"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Delete tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerDeleteTool(h *Handler, schema, entity string) {
|
||||||
|
name := toolName("delete", schema, entity)
|
||||||
|
description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity))
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithString("id",
|
||||||
|
mcp.Description("Primary key of the record to delete"),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
id, _ := args["id"].(string)
|
||||||
|
|
||||||
|
result, err := h.executeDelete(ctx, schema, entity, id)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Resource registration
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerModelResource(h *Handler, schema, entity string) {
|
||||||
|
resourceURI := buildModelName(schema, entity)
|
||||||
|
displayName := entity
|
||||||
|
if schema != "" {
|
||||||
|
displayName = schema + "." + entity
|
||||||
|
}
|
||||||
|
|
||||||
|
resource := mcp.NewResource(
|
||||||
|
resourceURI,
|
||||||
|
displayName,
|
||||||
|
mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)),
|
||||||
|
mcp.WithMIMEType("application/json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||||
|
limit := 100
|
||||||
|
options := common.RequestOptions{Limit: &limit}
|
||||||
|
|
||||||
|
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"data": data,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshaling resource: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []mcp.ResourceContents{
|
||||||
|
mcp.TextResourceContents{
|
||||||
|
URI: req.Params.URI,
|
||||||
|
MIMEType: "application/json",
|
||||||
|
Text: string(jsonBytes),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Argument parsing helpers
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
|
||||||
|
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||||
|
options := common.RequestOptions{}
|
||||||
|
|
||||||
|
// limit
|
||||||
|
if v, ok := args["limit"]; ok {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
limit := int(n)
|
||||||
|
options.Limit = &limit
|
||||||
|
case int:
|
||||||
|
options.Limit = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// offset
|
||||||
|
if v, ok := args["offset"]; ok {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
offset := int(n)
|
||||||
|
options.Offset = &offset
|
||||||
|
case int:
|
||||||
|
options.Offset = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cursor_forward / cursor_backward
|
||||||
|
if v, ok := args["cursor_forward"].(string); ok {
|
||||||
|
options.CursorForward = v
|
||||||
|
}
|
||||||
|
if v, ok := args["cursor_backward"].(string); ok {
|
||||||
|
options.CursorBackward = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// columns
|
||||||
|
options.Columns = parseStringArray(args["columns"])
|
||||||
|
|
||||||
|
// omit_columns
|
||||||
|
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||||
|
|
||||||
|
// filters — marshal each item and unmarshal into FilterOption
|
||||||
|
options.Filters = parseFilters(args["filters"])
|
||||||
|
|
||||||
|
// sort
|
||||||
|
options.Sort = parseSortOptions(args["sort"])
|
||||||
|
|
||||||
|
// preloads
|
||||||
|
options.Preload = parsePreloadOptions(args["preloads"])
|
||||||
|
|
||||||
|
return options
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStringArray(raw interface{}) []string {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
if s, ok := item.(string); ok {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFilters(raw interface{}) []common.FilterOption {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]common.FilterOption, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
b, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var f common.FilterOption
|
||||||
|
if err := json.Unmarshal(b, &f); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Column == "" || f.Operator == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Normalise logic operator
|
||||||
|
if strings.EqualFold(f.LogicOperator, "or") {
|
||||||
|
f.LogicOperator = "OR"
|
||||||
|
} else {
|
||||||
|
f.LogicOperator = "AND"
|
||||||
|
}
|
||||||
|
result = append(result, f)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSortOptions(raw interface{}) []common.SortOption {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]common.SortOption, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
b, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var s common.SortOption
|
||||||
|
if err := json.Unmarshal(b, &s); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.Column == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]common.PreloadOption, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
b, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var p common.PreloadOption
|
||||||
|
if err := json.Unmarshal(b, &p); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p.Relation == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, p)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalResult marshals a value to JSON and returns it as an MCP text result.
|
||||||
|
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
|
||||||
|
}
|
||||||
|
return mcp.NewToolResultText(string(b)), nil
|
||||||
|
}
|
||||||
@@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle joins
|
// Handle joins
|
||||||
if isJoin && expandJoins != nil {
|
if isJoin {
|
||||||
if joinClause, ok := expandJoins[prefix]; ok {
|
if expandJoins != nil {
|
||||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
joinSQL = jSQL
|
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||||
cursorCol = cRef + "." + field
|
joinSQL = jSQL
|
||||||
targetCol = prefix + "." + field
|
cursorCol = cRef + "." + field
|
||||||
|
targetCol = prefix + "." + field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cursorCol == "" {
|
||||||
|
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||||
|
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||||
|
|
||||||
|
opts := &ExtendedRequestOptions{
|
||||||
|
RequestOptions: common.RequestOptions{
|
||||||
|
Sort: []common.SortOption{
|
||||||
|
{Column: "fn.sortorder", Direction: "ASC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
opts.CursorForward = "8975"
|
||||||
|
|
||||||
|
tableName := "core.account"
|
||||||
|
pkName := "rid_account"
|
||||||
|
// modelColumns does not contain "sortorder" - it's a lateral join computed column
|
||||||
|
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||||
|
expandJoins := map[string]string{"fn": lateralJoin}
|
||||||
|
|
||||||
|
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||||
|
|
||||||
|
// Should contain the rewritten lateral join inside the EXISTS subquery
|
||||||
|
if !strings.Contains(filter, "cursor_select_fn") {
|
||||||
|
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should compare fn.sortorder values
|
||||||
|
if !strings.Contains(filter, "sortorder") {
|
||||||
|
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT contain empty comparison like "< "
|
||||||
|
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||||
|
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildPriorityChain(t *testing.T) {
|
func TestBuildPriorityChain(t *testing.T) {
|
||||||
clauses := []string{
|
clauses := []string{
|
||||||
"cursor_select.priority > posts.priority",
|
"cursor_select.priority > posts.priority",
|
||||||
|
|||||||
@@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Extract model columns for validation using the generic database function
|
// Extract model columns for validation using the generic database function
|
||||||
modelColumns := reflection.GetModelColumns(model)
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
// Build expand joins map (if needed in future)
|
// Build expand joins map: custom SQL joins are available in cursor subquery
|
||||||
var expandJoins map[string]string
|
expandJoins := make(map[string]string)
|
||||||
if len(options.Expand) > 0 {
|
for _, joinClause := range options.CustomSQLJoin {
|
||||||
expandJoins = make(map[string]string)
|
alias := extractJoinAlias(joinClause)
|
||||||
// TODO: Build actual JOIN SQL for each expand relation
|
if alias != "" {
|
||||||
// For now, pass empty map as joins are handled via Preload
|
expandJoins[alias] = joinClause
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
|
||||||
|
|
||||||
// Default sort to primary key when none provided
|
// Default sort to primary key when none provided
|
||||||
if len(options.Sort) == 0 {
|
if len(options.Sort) == 0 {
|
||||||
|
|||||||
@@ -552,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
|
|||||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||||
// - "JOIN roles r ON ..." -> "r"
|
// - "JOIN roles r ON ..." -> "r"
|
||||||
|
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
|
||||||
func extractJoinAlias(joinClause string) string {
|
func extractJoinAlias(joinClause string) string {
|
||||||
// Pattern: JOIN table_name [AS] alias ON ...
|
|
||||||
// We need to extract the alias (word before ON)
|
|
||||||
|
|
||||||
upperJoin := strings.ToUpper(joinClause)
|
upperJoin := strings.ToUpper(joinClause)
|
||||||
|
|
||||||
// Find the "JOIN" keyword position
|
// Find the "JOIN" keyword position
|
||||||
@@ -564,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the "ON" keyword position
|
// Lateral joins: alias is the word after the closing ) and before ON
|
||||||
|
if strings.Contains(upperJoin, "LATERAL") {
|
||||||
|
lastClose := strings.LastIndex(joinClause, ")")
|
||||||
|
if lastClose != -1 {
|
||||||
|
words := strings.Fields(joinClause[lastClose+1:])
|
||||||
|
// words should be like ["fn", "on", "true"] or ["on", "true"]
|
||||||
|
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
|
||||||
|
return words[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular joins: find the "ON" keyword position (first occurrence)
|
||||||
onIdx := strings.Index(upperJoin, " ON ")
|
onIdx := strings.Index(upperJoin, " ON ")
|
||||||
if onIdx == -1 {
|
if onIdx == -1 {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
|
|||||||
joinClause: "LEFT JOIN departments",
|
joinClause: "LEFT JOIN departments",
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "LATERAL join with alias",
|
||||||
|
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
|
||||||
|
expected: "fn",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LATERAL join with multiline subquery containing inner ON",
|
||||||
|
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
|
||||||
|
expected: "fn",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
Reference in New Issue
Block a user