feat(resolvemcp): add hook system for model operations

* Implement hooks for CRUD operations: before/after handle, read, create, update, delete.
* Introduce HookContext and HookRegistry for managing hooks.
* Allow registration and execution of multiple hooks per operation.

feat(resolvemcp): implement MCP tools for CRUD operations
* Register tools for reading, creating, updating, and deleting records.
* Define tool arguments and handle requests with appropriate responses.
* Support for resource registration with metadata.

fix(restheadspec): enhance cursor handling for joins
* Improve cursor filter generation to support lateral joins.
* Update join alias extraction to handle lateral joins correctly.
* Ensure cursor filters do not contain empty comparisons.

test(restheadspec): add tests for cursor filters and join alias extraction
* Create tests for lateral join scenarios in cursor filter generation.
* Validate join alias extraction for various join types, including lateral joins.
This commit is contained in:
Hein
2026-03-27 12:57:08 +02:00
parent 7a498edab7
commit 047a1cc187
13 changed files with 1555 additions and 17 deletions

5
go.mod
View File

@@ -40,6 +40,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.46.0
golang.org/x/oauth2 v0.34.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
@@ -78,6 +79,7 @@ require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@@ -86,6 +88,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mark3labs/mcp-go v0.46.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
@@ -131,6 +134,7 @@ require (
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
@@ -143,7 +147,6 @@ require (
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect

6
go.sum
View File

@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=

71
pkg/resolvemcp/context.go Normal file
View File

@@ -0,0 +1,71 @@
package resolvemcp
import "context"
type contextKey string
const (
contextKeySchema contextKey = "schema"
contextKeyEntity contextKey = "entity"
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
)
func WithSchema(ctx context.Context, schema string) context.Context {
return context.WithValue(ctx, contextKeySchema, schema)
}
func GetSchema(ctx context.Context) string {
if v := ctx.Value(contextKeySchema); v != nil {
return v.(string)
}
return ""
}
func WithEntity(ctx context.Context, entity string) context.Context {
return context.WithValue(ctx, contextKeyEntity, entity)
}
func GetEntity(ctx context.Context) string {
if v := ctx.Value(contextKeyEntity); v != nil {
return v.(string)
}
return ""
}
func WithTableName(ctx context.Context, tableName string) context.Context {
return context.WithValue(ctx, contextKeyTableName, tableName)
}
func GetTableName(ctx context.Context) string {
if v := ctx.Value(contextKeyTableName); v != nil {
return v.(string)
}
return ""
}
func WithModel(ctx context.Context, model interface{}) context.Context {
return context.WithValue(ctx, contextKeyModel, model)
}
func GetModel(ctx context.Context) interface{} {
return ctx.Value(contextKeyModel)
}
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
}
func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
return ctx
}

133
pkg/resolvemcp/cursor.go Normal file
View File

@@ -0,0 +1,133 @@
package resolvemcp
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
type cursorDirection int
const (
cursorForward cursorDirection = 1
cursorBackward cursorDirection = -1
)
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
func getCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
) (string, error) {
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
cursorID, direction := getActiveCursor(options)
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
sortItems := options.Sort
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
var whereClauses []string
reverse := direction < 0
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
if col == "" {
continue
}
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
desc := strings.EqualFold(s.Direction, "desc")
if reverse {
desc = !desc
}
cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
orSQL := buildCursorPriorityChain(whereClauses)
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
pkName,
cursorID,
orSQL,
)
return query, nil
}
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
if options.CursorForward != "" {
return options.CursorForward, cursorForward
}
if options.CursorBackward != "" {
return options.CursorBackward, cursorBackward
}
return "", 0
}
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) {
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil
}
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil
}
}
} else {
return "cursor_select." + field, tableName + "." + field, nil
}
if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
}
return "", "", fmt.Errorf("invalid column: %s", field)
}
func buildCursorPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

644
pkg/resolvemcp/handler.go Normal file
View File

@@ -0,0 +1,644 @@
package resolvemcp
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mark3labs/mcp-go/server"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Handler exposes registered database models as MCP tools and resources.
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
mcpServer *server.MCPServer
name string
version string
}
// NewHandler creates a Handler with the given database and model registry.
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
return &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
name: "resolvemcp",
version: "1.0.0",
}
}
// Hooks returns the hook registry.
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// GetDatabase returns the underlying database.
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
func (h *Handler) MCPServer() *server.MCPServer {
return h.mcpServer
}
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
fullName := buildModelName(schema, entity)
if err := h.registry.RegisterModel(fullName, model); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string {
if schema == "" {
return entity
}
return fmt.Sprintf("%s.%s", schema, entity)
}
// getTableName returns the fully qualified table name for a model.
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
if idx := strings.LastIndex(tableName, "."); idx != -1 {
return tableName[:idx], tableName[idx+1:]
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), tableName
}
return defaultSchema, tableName
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), entity
}
return defaultSchema, entity
}
// executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err)
}
unwrapped, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, nil, fmt.Errorf("invalid model: %w", err)
}
model = unwrapped.Model
modelType := unwrapped.ModelType
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
validator := common.NewColumnValidator(model)
options = validator.FilterRequestOptions(options)
// BeforeHandle hook
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "read",
Options: options,
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, nil, err
}
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
query := h.db.NewSelect().Model(modelPtr)
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Column selection
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
options.Columns = reflection.GetSQLModelColumns(model)
}
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
for _, cu := range options.ComputedColumns {
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
}
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters
query = h.applyFilters(query, options.Filters)
// Custom operators
for _, customOp := range options.CustomOperators {
query = query.Where(customOp.SQL)
}
// Sorting
for _, sort := range options.Sort {
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Cursor pagination
if options.CursorForward != "" || options.CursorBackward != "" {
pkName := reflection.GetPrimaryKeyName(model)
modelColumns := reflection.GetModelColumns(model)
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err)
}
if cursorFilter != "" {
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
sanitized = common.EnsureOuterParentheses(sanitized)
if sanitized != "" {
query = query.Where(sanitized)
}
}
}
// Count
total, err := query.Count(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err)
}
// Pagination
if options.Limit != nil && *options.Limit > 0 {
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
query = query.Offset(*options.Offset)
}
// BeforeRead hook
hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
return nil, nil, err
}
var data interface{}
if id != "" {
singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := query.Scan(ctx, singleResult); err != nil {
if err == sql.ErrNoRows {
return nil, nil, fmt.Errorf("record not found")
}
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = singleResult
} else {
if err := query.Scan(ctx, modelPtr); err != nil {
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = reflect.ValueOf(modelPtr).Elem().Interface()
}
limit := 0
offset := 0
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Count is the number of records in this page, not the total.
var pageCount int64
if id != "" {
pageCount = 1
} else {
pageCount = int64(reflect.ValueOf(data).Len())
}
metadata := &common.Metadata{
Total: int64(total),
Filtered: int64(total),
Count: pageCount,
Limit: limit,
Offset: offset,
}
// AfterRead hook
hookCtx.Result = data
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
return nil, nil, err
}
return data, metadata, nil
}
// executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "create",
Data: data,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
return nil, err
}
// Use potentially modified data
data = hookCtx.Data
switch v := data.(type) {
case map[string]interface{}:
query := h.db.NewInsert().Table(tableName)
for key, value := range v {
query = query.Value(key, value)
}
if _, err := query.Exec(ctx); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
hookCtx.Result = v
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return v, nil
case []interface{}:
results := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
itemMap, ok := item.(map[string]interface{})
if !ok {
return fmt.Errorf("each item must be an object")
}
q := tx.NewInsert().Table(tableName)
for key, value := range itemMap {
q = q.Value(key, value)
}
if _, err := q.Exec(ctx); err != nil {
return err
}
results = append(results, item)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("batch create error: %w", err)
}
hookCtx.Result = results
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return results, nil
default:
return nil, fmt.Errorf("data must be an object or array of objects")
}
}
// executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
updates, ok := data.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("data must be an object")
}
if id == "" {
if idVal, exists := updates["id"]; exists {
id = fmt.Sprintf("%v", idVal)
}
}
if id == "" {
return nil, fmt.Errorf("update requires an ID")
}
pkName := reflection.GetPrimaryKeyName(model)
var updateResult interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Read existing record
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
existingRecord := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no records found to update")
}
return fmt.Errorf("error fetching existing record: %w", err)
}
// Convert to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
return fmt.Errorf("error marshaling existing record: %w", err)
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
return fmt.Errorf("error unmarshaling existing record: %w", err)
}
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "update",
ID: id,
Data: updates,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return err
}
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
updates = modifiedData
}
// Merge non-nil, non-empty values
for key, newValue := range updates {
if newValue == nil {
continue
}
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
}
existingMap[key] = newValue
}
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
res, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("error updating record: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("no records found to update")
}
updateResult = existingMap
hookCtx.Result = updateResult
return h.hooks.Execute(AfterUpdate, hookCtx)
})
if err != nil {
return nil, err
}
return updateResult, nil
}
// executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
if id == "" {
return nil, fmt.Errorf("delete requires an ID")
}
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
pkName := reflection.GetPrimaryKeyName(model)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "delete",
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
return nil, err
}
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
var recordToDelete interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
record := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(record).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("record not found")
}
return fmt.Errorf("error fetching record: %w", err)
}
res, err := tx.NewDelete().Table(tableName).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete error: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("record not found or already deleted")
}
recordToDelete = record
hookCtx.Tx = tx
hookCtx.Result = record
return h.hooks.Execute(AfterDelete, hookCtx)
})
if err != nil {
return nil, err
}
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
return recordToDelete, nil
}
// applyFilters applies all filters with OR grouping logic.
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
}
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
switch filter.Operator {
case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
case "neq", "!=", "<>":
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
case "gt", ">":
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
case "gte", ">=":
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
case "lt", "<":
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
case "lte", "<=":
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
case "in":
condition, args := common.BuildInCondition(filter.Column, filter.Value)
return condition, args
case "is_null":
return fmt.Sprintf("%s IS NULL", filter.Column), nil
case "is_not_null":
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
}
return "", nil
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for _, preload := range preloads {
if preload.Relation == "" {
continue
}
query = query.PreloadRelation(preload.Relation)
}
return query, nil
}

113
pkg/resolvemcp/hooks.go Normal file
View File

@@ -0,0 +1,113 @@
package resolvemcp
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
BeforeHandle HookType = "before_handle"
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Operation string
ID string
Data interface{}
Result interface{}
Error error
Query common.SelectQuery
Abort bool
AbortMessage string
AbortCode int
Tx common.Database
}
// HookFunc is the signature for hook functions
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
if ctx.Abort {
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
}
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
}
func (r *HookRegistry) HasHooks(hookType HookType) bool {
hooks, exists := r.hooks[hookType]
return exists && len(hooks) > 0
}

View File

@@ -0,0 +1,83 @@
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
// and resources over HTTP/SSE transport.
//
// It mirrors the resolvespec package patterns:
// - Same model registration API
// - Same filter, sort, cursor pagination, preload options
// - Same lifecycle hook system
//
// Usage:
//
// handler := resolvemcp.NewHandlerWithGORM(db)
// handler.RegisterModel("public", "users", &User{})
//
// r := mux.NewRouter()
// resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080")
package resolvemcp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/mark3labs/mcp-go/server"
"github.com/uptrace/bun"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
func NewHandlerWithGORM(db *gorm.DB) *Handler {
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry())
}
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
func NewHandlerWithBun(db *bun.DB) *Handler {
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry())
}
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
func NewHandlerWithDB(db common.Database) *Handler {
return NewHandler(db, modelregistry.NewModelRegistry())
}
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router.
//
// baseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
//
// Two routes are registered:
// - GET /mcp/sse — SSE connection endpoint (client subscribes here)
// - POST /mcp/message — JSON-RPC message endpoint (client sends requests here)
//
// To protect these routes with authentication, wrap the mux router or apply middleware
// before calling SetupMuxRoutes.
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, baseURL string) {
sseServer := server.NewSSEServer(
handler.mcpServer,
server.WithBaseURL(baseURL),
server.WithBasePath("/mcp"),
)
muxRouter.Handle("/mcp/sse", sseServer.SSEHandler()).Methods("GET", "OPTIONS")
muxRouter.Handle("/mcp/message", sseServer.MessageHandler()).Methods("POST", "OPTIONS")
// Convenience: also expose the full SSE server at /mcp for clients that
// use ServeHTTP directly (e.g. net/http default mux).
muxRouter.PathPrefix("/mcp").Handler(http.StripPrefix("/mcp", sseServer))
}
// NewSSEServer creates an *server.SSEServer that can be mounted manually,
// useful when integrating with non-Mux routers or adding extra middleware.
//
// sseServer := resolvemcp.NewSSEServer(handler, "http://localhost:8080", "/mcp")
// http.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
func NewSSEServer(handler *Handler, baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer(
handler.mcpServer,
server.WithBaseURL(baseURL),
server.WithBasePath(basePath),
)
}

415
pkg/resolvemcp/tools.go Normal file
View File

@@ -0,0 +1,415 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// toolName builds the MCP tool name for a given operation and model.
func toolName(operation, schema, entity string) string {
if schema == "" {
return fmt.Sprintf("%s_%s", operation, entity)
}
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
}
// registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
registerReadTool(h, schema, entity)
registerCreateTool(h, schema, entity)
registerUpdateTool(h, schema, entity)
registerDeleteTool(h, schema, entity)
registerModelResource(h, schema, entity)
logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity)
}
// --------------------------------------------------------------------------
// Read tool
// --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string) {
name := toolName("read", schema, entity)
description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity))
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description("Primary key of a single record to fetch (optional)"),
),
mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return"),
),
mcp.WithNumber("offset",
mcp.Description("Number of records to skip"),
),
mcp.WithString("cursor_forward",
mcp.Description("Cursor value for the next page (primary key of last record on current page)"),
),
mcp.WithString("cursor_backward",
mcp.Description("Cursor value for the previous page"),
),
mcp.WithArray("columns",
mcp.Description("List of column names to include in the result"),
),
mcp.WithArray("omit_columns",
mcp.Description("List of column names to exclude from the result"),
),
mcp.WithArray("filters",
mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`),
),
mcp.WithArray("sort",
mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`),
),
mcp.WithArray("preloads",
mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
options := parseRequestOptions(args)
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": data,
"metadata": metadata,
})
})
}
// --------------------------------------------------------------------------
// Create tool
// --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string) {
name := toolName("create", schema, entity)
description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity))
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithObject("data",
mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
result, err := h.executeCreate(ctx, schema, entity, data)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Update tool
// --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string) {
name := toolName("update", schema, entity)
description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity))
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description("Primary key of the record to update"),
),
mcp.WithObject("data",
mcp.Description("Fields to update (non-null fields will be merged into the existing record)"),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
dataMap, ok := data.(map[string]interface{})
if !ok {
return mcp.NewToolResultError("data must be an object"), nil
}
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Delete tool
// --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string) {
name := toolName("delete", schema, entity)
description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity))
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description("Primary key of the record to delete"),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
result, err := h.executeDelete(ctx, schema, entity, id)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Resource registration
// --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string) {
resourceURI := buildModelName(schema, entity)
displayName := entity
if schema != "" {
displayName = schema + "." + entity
}
resource := mcp.NewResource(
resourceURI,
displayName,
mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)),
mcp.WithMIMEType("application/json"),
)
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
limit := 100
options := common.RequestOptions{Limit: &limit}
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
if err != nil {
return nil, err
}
payload := map[string]interface{}{
"data": data,
"metadata": metadata,
}
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling resource: %w", err)
}
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "application/json",
Text: string(jsonBytes),
},
}, nil
})
}
// --------------------------------------------------------------------------
// Argument parsing helpers
// --------------------------------------------------------------------------
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{}
// limit
if v, ok := args["limit"]; ok {
switch n := v.(type) {
case float64:
limit := int(n)
options.Limit = &limit
case int:
options.Limit = &n
}
}
// offset
if v, ok := args["offset"]; ok {
switch n := v.(type) {
case float64:
offset := int(n)
options.Offset = &offset
case int:
options.Offset = &n
}
}
// cursor_forward / cursor_backward
if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v
}
if v, ok := args["cursor_backward"].(string); ok {
options.CursorBackward = v
}
// columns
options.Columns = parseStringArray(args["columns"])
// omit_columns
options.OmitColumns = parseStringArray(args["omit_columns"])
// filters — marshal each item and unmarshal into FilterOption
options.Filters = parseFilters(args["filters"])
// sort
options.Sort = parseSortOptions(args["sort"])
// preloads
options.Preload = parsePreloadOptions(args["preloads"])
return options
}
func parseStringArray(raw interface{}) []string {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]string, 0, len(items))
for _, item := range items {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}
func parseFilters(raw interface{}) []common.FilterOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.FilterOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var f common.FilterOption
if err := json.Unmarshal(b, &f); err != nil {
continue
}
if f.Column == "" || f.Operator == "" {
continue
}
// Normalise logic operator
if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR"
} else {
f.LogicOperator = "AND"
}
result = append(result, f)
}
return result
}
func parseSortOptions(raw interface{}) []common.SortOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.SortOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var s common.SortOption
if err := json.Unmarshal(b, &s); err != nil {
continue
}
if s.Column == "" {
continue
}
result = append(result, s)
}
return result
}
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.PreloadOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var p common.PreloadOption
if err := json.Unmarshal(b, &p); err != nil {
continue
}
if p.Relation == "" {
continue
}
result = append(result, p)
}
return result
}
// marshalResult marshals a value to JSON and returns it as an MCP text result.
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
b, err := json.Marshal(v)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
}
return mcp.NewToolResultText(string(b)), nil
}

View File

@@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
}
// Handle joins
if isJoin && expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}

View File

@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
}
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "fn.sortorder", Direction: "ASC"},
},
},
}
opts.CursorForward = "8975"
tableName := "core.account"
pkName := "rid_account"
// modelColumns does not contain "sortorder" - it's a lateral join computed column
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
// Should contain the rewritten lateral join inside the EXISTS subquery
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
// Should compare fn.sortorder values
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
// Should NOT contain empty comparison like "< "
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestBuildPriorityChain(t *testing.T) {
clauses := []string{
"cursor_select.priority > posts.priority",

View File

@@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation using the generic database function
modelColumns := reflection.GetModelColumns(model)
// Build expand joins map (if needed in future)
var expandJoins map[string]string
if len(options.Expand) > 0 {
expandJoins = make(map[string]string)
// TODO: Build actual JOIN SQL for each expand relation
// For now, pass empty map as joins are handled via Preload
// Build expand joins map: custom SQL joins are available in cursor subquery
expandJoins := make(map[string]string)
for _, joinClause := range options.CustomSQLJoin {
alias := extractJoinAlias(joinClause)
if alias != "" {
expandJoins[alias] = joinClause
}
}
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
// Default sort to primary key when none provided
if len(options.Sort) == 0 {

View File

@@ -552,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
// - "LEFT JOIN departments d ON ..." -> "d"
// - "INNER JOIN users AS u ON ..." -> "u"
// - "JOIN roles r ON ..." -> "r"
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
func extractJoinAlias(joinClause string) string {
// Pattern: JOIN table_name [AS] alias ON ...
// We need to extract the alias (word before ON)
upperJoin := strings.ToUpper(joinClause)
// Find the "JOIN" keyword position
@@ -564,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
return ""
}
// Find the "ON" keyword position
// Lateral joins: alias is the word after the closing ) and before ON
if strings.Contains(upperJoin, "LATERAL") {
lastClose := strings.LastIndex(joinClause, ")")
if lastClose != -1 {
words := strings.Fields(joinClause[lastClose+1:])
// words should be like ["fn", "on", "true"] or ["on", "true"]
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
return words[0]
}
}
return ""
}
// Regular joins: find the "ON" keyword position (first occurrence)
onIdx := strings.Index(upperJoin, " ON ")
if onIdx == -1 {
return ""

View File

@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
joinClause: "LEFT JOIN departments",
expected: "",
},
{
name: "LATERAL join with alias",
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
expected: "fn",
},
{
name: "LATERAL join with multiline subquery containing inner ON",
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
expected: "fn",
},
}
for _, tt := range tests {