diff --git a/go.mod b/go.mod index e471573..f1e3fdd 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 go.uber.org/zap v1.27.1 golang.org/x/crypto v0.46.0 + golang.org/x/oauth2 v0.34.0 golang.org/x/time v0.14.0 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 @@ -78,6 +79,7 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/snappy v1.0.0 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -86,6 +88,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect + github.com/mark3labs/mcp-go v0.46.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.1.0 // indirect @@ -131,6 +134,7 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect @@ -143,7 +147,6 @@ require ( golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect golang.org/x/mod v0.31.0 // indirect golang.org/x/net v0.48.0 // indirect - golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect diff --git a/go.sum b/go.sum index b746a77..72a6556 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo= +github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= @@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/pkg/resolvemcp/context.go b/pkg/resolvemcp/context.go new file mode 100644 index 0000000..f8e97f7 --- /dev/null +++ b/pkg/resolvemcp/context.go @@ -0,0 +1,71 @@ +package resolvemcp + +import "context" + +type contextKey string + +const ( + contextKeySchema contextKey = "schema" + contextKeyEntity contextKey = "entity" + contextKeyTableName contextKey = "tableName" + contextKeyModel contextKey = "model" + contextKeyModelPtr contextKey = "modelPtr" +) + +func WithSchema(ctx context.Context, schema string) context.Context { + return context.WithValue(ctx, contextKeySchema, schema) +} + +func GetSchema(ctx context.Context) string { + if v := ctx.Value(contextKeySchema); v != nil { + return v.(string) + } + return "" +} + +func WithEntity(ctx context.Context, entity string) context.Context { + return context.WithValue(ctx, contextKeyEntity, entity) +} + +func GetEntity(ctx context.Context) string { + if v := ctx.Value(contextKeyEntity); v != nil { + return v.(string) + } + return "" +} + +func WithTableName(ctx context.Context, tableName string) context.Context { + return context.WithValue(ctx, contextKeyTableName, tableName) +} + +func GetTableName(ctx context.Context) string { + if v := ctx.Value(contextKeyTableName); v != nil { + return v.(string) + } + return "" +} + +func WithModel(ctx context.Context, model interface{}) context.Context { + return context.WithValue(ctx, contextKeyModel, model) +} + +func GetModel(ctx context.Context) interface{} { + return ctx.Value(contextKeyModel) +} + +func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context { + return context.WithValue(ctx, contextKeyModelPtr, modelPtr) +} + +func GetModelPtr(ctx context.Context) interface{} { + return ctx.Value(contextKeyModelPtr) +} + +func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context { + ctx = WithSchema(ctx, schema) + ctx = WithEntity(ctx, entity) + ctx = WithTableName(ctx, tableName) + ctx = WithModel(ctx, model) + ctx = WithModelPtr(ctx, modelPtr) + return ctx +} diff --git a/pkg/resolvemcp/cursor.go b/pkg/resolvemcp/cursor.go new file mode 100644 index 0000000..89668f1 --- /dev/null +++ b/pkg/resolvemcp/cursor.go @@ -0,0 +1,133 @@ +package resolvemcp + +// Cursor-based pagination adapted from pkg/resolvespec/cursor.go. + +import ( + "fmt" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +type cursorDirection int + +const ( + cursorForward cursorDirection = 1 + cursorBackward cursorDirection = -1 +) + +// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination. +func getCursorFilter( + tableName string, + pkName string, + modelColumns []string, + options common.RequestOptions, +) (string, error) { + fullTableName := tableName + if strings.Contains(tableName, ".") { + tableName = strings.SplitN(tableName, ".", 2)[1] + } + + cursorID, direction := getActiveCursor(options) + if cursorID == "" { + return "", fmt.Errorf("no cursor provided for table %s", tableName) + } + + sortItems := options.Sort + if len(sortItems) == 0 { + return "", fmt.Errorf("no sort columns defined") + } + + var whereClauses []string + reverse := direction < 0 + + for _, s := range sortItems { + col := strings.TrimSpace(s.Column) + if col == "" { + continue + } + + parts := strings.Split(col, ".") + field := strings.TrimSpace(parts[len(parts)-1]) + prefix := strings.Join(parts[:len(parts)-1], ".") + + desc := strings.EqualFold(s.Direction, "desc") + if reverse { + desc = !desc + } + + cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns) + if err != nil { + logger.Warn("Skipping invalid sort column %q: %v", col, err) + continue + } + + op := "<" + if desc { + op = ">" + } + whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol)) + } + + if len(whereClauses) == 0 { + return "", fmt.Errorf("no valid sort columns after filtering") + } + + orSQL := buildCursorPriorityChain(whereClauses) + + query := fmt.Sprintf(`EXISTS ( + SELECT 1 + FROM %s cursor_select + WHERE cursor_select.%s = %s + AND (%s) +)`, + fullTableName, + pkName, + cursorID, + orSQL, + ) + + return query, nil +} + +func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) { + if options.CursorForward != "" { + return options.CursorForward, cursorForward + } + if options.CursorBackward != "" { + return options.CursorBackward, cursorBackward + } + return "", 0 +} + +func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) { + if strings.Contains(field, "->") { + return "cursor_select." + field, tableName + "." + field, nil + } + + if modelColumns != nil { + for _, col := range modelColumns { + if strings.EqualFold(col, field) { + return "cursor_select." + field, tableName + "." + field, nil + } + } + } else { + return "cursor_select." + field, tableName + "." + field, nil + } + + if prefix != "" && prefix != tableName { + return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field) + } + + return "", "", fmt.Errorf("invalid column: %s", field) +} + +func buildCursorPriorityChain(clauses []string) string { + var or []string + for i := 0; i < len(clauses); i++ { + and := strings.Join(clauses[:i+1], "\n AND ") + or = append(or, "("+and+")") + } + return strings.Join(or, "\n OR ") +} diff --git a/pkg/resolvemcp/handler.go b/pkg/resolvemcp/handler.go new file mode 100644 index 0000000..589e5f7 --- /dev/null +++ b/pkg/resolvemcp/handler.go @@ -0,0 +1,644 @@ +package resolvemcp + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/mark3labs/mcp-go/server" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +// Handler exposes registered database models as MCP tools and resources. +type Handler struct { + db common.Database + registry common.ModelRegistry + hooks *HookRegistry + mcpServer *server.MCPServer + name string + version string +} + +// NewHandler creates a Handler with the given database and model registry. +func NewHandler(db common.Database, registry common.ModelRegistry) *Handler { + return &Handler{ + db: db, + registry: registry, + hooks: NewHookRegistry(), + mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"), + name: "resolvemcp", + version: "1.0.0", + } +} + +// Hooks returns the hook registry. +func (h *Handler) Hooks() *HookRegistry { + return h.hooks +} + +// GetDatabase returns the underlying database. +func (h *Handler) GetDatabase() common.Database { + return h.db +} + +// MCPServer returns the underlying MCP server, e.g. to add custom tools. +func (h *Handler) MCPServer() *server.MCPServer { + return h.mcpServer +} + +// RegisterModel registers a model and immediately exposes it as MCP tools and a resource. +func (h *Handler) RegisterModel(schema, entity string, model interface{}) error { + fullName := buildModelName(schema, entity) + if err := h.registry.RegisterModel(fullName, model); err != nil { + return err + } + registerModelTools(h, schema, entity, model) + return nil +} + +// buildModelName builds the registry key for a model (same format as resolvespec). +func buildModelName(schema, entity string) string { + if schema == "" { + return entity + } + return fmt.Sprintf("%s.%s", schema, entity) +} + +// getTableName returns the fully qualified table name for a model. +func (h *Handler) getTableName(schema, entity string, model interface{}) string { + schemaName, tableName := h.getSchemaAndTable(schema, entity, model) + if schemaName != "" { + if h.db.DriverName() == "sqlite" { + return fmt.Sprintf("%s_%s", schemaName, tableName) + } + return fmt.Sprintf("%s.%s", schemaName, tableName) + } + return tableName +} + +func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) { + if tableProvider, ok := model.(common.TableNameProvider); ok { + tableName := tableProvider.TableName() + if idx := strings.LastIndex(tableName, "."); idx != -1 { + return tableName[:idx], tableName[idx+1:] + } + if schemaProvider, ok := model.(common.SchemaProvider); ok { + return schemaProvider.SchemaName(), tableName + } + return defaultSchema, tableName + } + if schemaProvider, ok := model.(common.SchemaProvider); ok { + return schemaProvider.SchemaName(), entity + } + return defaultSchema, entity +} + +// executeRead reads records from the database and returns raw data + metadata. +func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) { + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + return nil, nil, fmt.Errorf("model not found: %w", err) + } + + unwrapped, err := common.ValidateAndUnwrapModel(model) + if err != nil { + return nil, nil, fmt.Errorf("invalid model: %w", err) + } + + model = unwrapped.Model + modelType := unwrapped.ModelType + tableName := h.getTableName(schema, entity, model) + ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr) + + validator := common.NewColumnValidator(model) + options = validator.FilterRequestOptions(options) + + // BeforeHandle hook + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Operation: "read", + Options: options, + ID: id, + Tx: h.db, + } + if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil { + return nil, nil, err + } + + sliceType := reflect.SliceOf(reflect.PointerTo(modelType)) + modelPtr := reflect.New(sliceType).Interface() + + query := h.db.NewSelect().Model(modelPtr) + + tempInstance := reflect.New(modelType).Interface() + if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" { + query = query.Table(tableName) + } + + // Column selection + if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 { + options.Columns = reflection.GetSQLModelColumns(model) + } + for _, col := range options.Columns { + query = query.Column(reflection.ExtractSourceColumn(col)) + } + for _, cu := range options.ComputedColumns { + query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name)) + } + + // Preloads + if len(options.Preload) > 0 { + var err error + query, err = h.applyPreloads(model, query, options.Preload) + if err != nil { + return nil, nil, fmt.Errorf("failed to apply preloads: %w", err) + } + } + + // Filters + query = h.applyFilters(query, options.Filters) + + // Custom operators + for _, customOp := range options.CustomOperators { + query = query.Where(customOp.SQL) + } + + // Sorting + for _, sort := range options.Sort { + direction := "ASC" + if strings.EqualFold(sort.Direction, "desc") { + direction = "DESC" + } + query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction)) + } + + // Cursor pagination + if options.CursorForward != "" || options.CursorBackward != "" { + pkName := reflection.GetPrimaryKeyName(model) + modelColumns := reflection.GetModelColumns(model) + + if len(options.Sort) == 0 { + options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} + } + + cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options) + if err != nil { + return nil, nil, fmt.Errorf("cursor error: %w", err) + } + + if cursorFilter != "" { + sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options) + sanitized = common.EnsureOuterParentheses(sanitized) + if sanitized != "" { + query = query.Where(sanitized) + } + } + } + + // Count + total, err := query.Count(ctx) + if err != nil { + return nil, nil, fmt.Errorf("error counting records: %w", err) + } + + // Pagination + if options.Limit != nil && *options.Limit > 0 { + query = query.Limit(*options.Limit) + } + if options.Offset != nil && *options.Offset > 0 { + query = query.Offset(*options.Offset) + } + + // BeforeRead hook + hookCtx.Query = query + if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil { + return nil, nil, err + } + + var data interface{} + if id != "" { + singleResult := reflect.New(modelType).Interface() + pkName := reflection.GetPrimaryKeyName(singleResult) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + if err := query.Scan(ctx, singleResult); err != nil { + if err == sql.ErrNoRows { + return nil, nil, fmt.Errorf("record not found") + } + return nil, nil, fmt.Errorf("query error: %w", err) + } + data = singleResult + } else { + if err := query.Scan(ctx, modelPtr); err != nil { + return nil, nil, fmt.Errorf("query error: %w", err) + } + data = reflect.ValueOf(modelPtr).Elem().Interface() + } + + limit := 0 + offset := 0 + if options.Limit != nil { + limit = *options.Limit + } + if options.Offset != nil { + offset = *options.Offset + } + + // Count is the number of records in this page, not the total. + var pageCount int64 + if id != "" { + pageCount = 1 + } else { + pageCount = int64(reflect.ValueOf(data).Len()) + } + + metadata := &common.Metadata{ + Total: int64(total), + Filtered: int64(total), + Count: pageCount, + Limit: limit, + Offset: offset, + } + + // AfterRead hook + hookCtx.Result = data + if err := h.hooks.Execute(AfterRead, hookCtx); err != nil { + return nil, nil, err + } + + return data, metadata, nil +} + +// executeCreate inserts one or more records. +func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) { + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + return nil, fmt.Errorf("model not found: %w", err) + } + + result, err := common.ValidateAndUnwrapModel(model) + if err != nil { + return nil, fmt.Errorf("invalid model: %w", err) + } + + model = result.Model + tableName := h.getTableName(schema, entity, model) + ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr) + + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Operation: "create", + Data: data, + Tx: h.db, + } + if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil { + return nil, err + } + if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil { + return nil, err + } + + // Use potentially modified data + data = hookCtx.Data + + switch v := data.(type) { + case map[string]interface{}: + query := h.db.NewInsert().Table(tableName) + for key, value := range v { + query = query.Value(key, value) + } + if _, err := query.Exec(ctx); err != nil { + return nil, fmt.Errorf("create error: %w", err) + } + hookCtx.Result = v + if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { + return nil, fmt.Errorf("AfterCreate hook failed: %w", err) + } + return v, nil + + case []interface{}: + results := make([]interface{}, 0, len(v)) + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + for _, item := range v { + itemMap, ok := item.(map[string]interface{}) + if !ok { + return fmt.Errorf("each item must be an object") + } + q := tx.NewInsert().Table(tableName) + for key, value := range itemMap { + q = q.Value(key, value) + } + if _, err := q.Exec(ctx); err != nil { + return err + } + results = append(results, item) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("batch create error: %w", err) + } + hookCtx.Result = results + if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { + return nil, fmt.Errorf("AfterCreate hook failed: %w", err) + } + return results, nil + + default: + return nil, fmt.Errorf("data must be an object or array of objects") + } +} + +// executeUpdate updates a record by ID. +func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) { + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + return nil, fmt.Errorf("model not found: %w", err) + } + + result, err := common.ValidateAndUnwrapModel(model) + if err != nil { + return nil, fmt.Errorf("invalid model: %w", err) + } + + model = result.Model + tableName := h.getTableName(schema, entity, model) + ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr) + + updates, ok := data.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("data must be an object") + } + + if id == "" { + if idVal, exists := updates["id"]; exists { + id = fmt.Sprintf("%v", idVal) + } + } + if id == "" { + return nil, fmt.Errorf("update requires an ID") + } + + pkName := reflection.GetPrimaryKeyName(model) + + var updateResult interface{} + err = h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Read existing record + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + existingRecord := reflect.New(modelType).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Column("*"). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("no records found to update") + } + return fmt.Errorf("error fetching existing record: %w", err) + } + + // Convert to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("error marshaling existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("error unmarshaling existing record: %w", err) + } + + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Operation: "update", + ID: id, + Data: updates, + Tx: tx, + } + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return err + } + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + updates = modifiedData + } + + // Merge non-nil, non-empty values + for key, newValue := range updates { + if newValue == nil { + continue + } + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + existingMap[key] = newValue + } + + q := tx.NewUpdate().Table(tableName).SetMap(existingMap). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + res, err := q.Exec(ctx) + if err != nil { + return fmt.Errorf("error updating record: %w", err) + } + if res.RowsAffected() == 0 { + return fmt.Errorf("no records found to update") + } + + updateResult = existingMap + hookCtx.Result = updateResult + return h.hooks.Execute(AfterUpdate, hookCtx) + }) + + if err != nil { + return nil, err + } + return updateResult, nil +} + +// executeDelete deletes a record by ID. +func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) { + if id == "" { + return nil, fmt.Errorf("delete requires an ID") + } + + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + return nil, fmt.Errorf("model not found: %w", err) + } + + result, err := common.ValidateAndUnwrapModel(model) + if err != nil { + return nil, fmt.Errorf("invalid model: %w", err) + } + + model = result.Model + tableName := h.getTableName(schema, entity, model) + ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr) + + pkName := reflection.GetPrimaryKeyName(model) + + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Operation: "delete", + ID: id, + Tx: h.db, + } + if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil { + return nil, err + } + if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil { + return nil, err + } + + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + var recordToDelete interface{} + + err = h.db.RunInTransaction(ctx, func(tx common.Database) error { + record := reflect.New(modelType).Interface() + selectQuery := tx.NewSelect().Model(record). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("record not found") + } + return fmt.Errorf("error fetching record: %w", err) + } + + res, err := tx.NewDelete().Table(tableName). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id). + Exec(ctx) + if err != nil { + return fmt.Errorf("delete error: %w", err) + } + if res.RowsAffected() == 0 { + return fmt.Errorf("record not found or already deleted") + } + + recordToDelete = record + hookCtx.Tx = tx + hookCtx.Result = record + return h.hooks.Execute(AfterDelete, hookCtx) + }) + if err != nil { + return nil, err + } + + logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity) + return recordToDelete, nil +} + +// applyFilters applies all filters with OR grouping logic. +func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery { + if len(filters) == 0 { + return query + } + + i := 0 + for i < len(filters) { + startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR") + + if startORGroup { + orGroup := []common.FilterOption{filters[i]} + j := i + 1 + for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") { + orGroup = append(orGroup, filters[j]) + j++ + } + query = h.applyFilterGroup(query, orGroup) + i = j + } else { + condition, args := h.buildFilterCondition(filters[i]) + if condition != "" { + query = query.Where(condition, args...) + } + i++ + } + } + + return query +} + +func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery { + var conditions []string + var args []interface{} + + for _, filter := range filters { + condition, filterArgs := h.buildFilterCondition(filter) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, filterArgs...) + } + } + + if len(conditions) == 0 { + return query + } + if len(conditions) == 1 { + return query.Where(conditions[0], args...) + } + return query.Where("("+strings.Join(conditions, " OR ")+")", args...) +} + +func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) { + switch filter.Operator { + case "eq", "=": + return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value} + case "neq", "!=", "<>": + return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value} + case "gt", ">": + return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value} + case "gte", ">=": + return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value} + case "lt", "<": + return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value} + case "lte", "<=": + return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value} + case "like": + return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value} + case "ilike": + return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value} + case "in": + condition, args := common.BuildInCondition(filter.Column, filter.Value) + return condition, args + case "is_null": + return fmt.Sprintf("%s IS NULL", filter.Column), nil + case "is_not_null": + return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil + } + return "", nil +} + +func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { + for _, preload := range preloads { + if preload.Relation == "" { + continue + } + query = query.PreloadRelation(preload.Relation) + } + return query, nil +} diff --git a/pkg/resolvemcp/hooks.go b/pkg/resolvemcp/hooks.go new file mode 100644 index 0000000..b98373a --- /dev/null +++ b/pkg/resolvemcp/hooks.go @@ -0,0 +1,113 @@ +package resolvemcp + +import ( + "context" + "fmt" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +// HookType defines the type of hook to execute +type HookType string + +const ( + // BeforeHandle fires after model resolution, before operation dispatch. + BeforeHandle HookType = "before_handle" + + BeforeRead HookType = "before_read" + AfterRead HookType = "after_read" + + BeforeCreate HookType = "before_create" + AfterCreate HookType = "after_create" + + BeforeUpdate HookType = "before_update" + AfterUpdate HookType = "after_update" + + BeforeDelete HookType = "before_delete" + AfterDelete HookType = "after_delete" +) + +// HookContext contains all the data available to a hook +type HookContext struct { + Context context.Context + Handler *Handler + Schema string + Entity string + Model interface{} + Options common.RequestOptions + Operation string + ID string + Data interface{} + Result interface{} + Error error + Query common.SelectQuery + Abort bool + AbortMessage string + AbortCode int + Tx common.Database +} + +// HookFunc is the signature for hook functions +type HookFunc func(*HookContext) error + +// HookRegistry manages all registered hooks +type HookRegistry struct { + hooks map[HookType][]HookFunc +} + +func NewHookRegistry() *HookRegistry { + return &HookRegistry{ + hooks: make(map[HookType][]HookFunc), + } +} + +func (r *HookRegistry) Register(hookType HookType, hook HookFunc) { + if r.hooks == nil { + r.hooks = make(map[HookType][]HookFunc) + } + r.hooks[hookType] = append(r.hooks[hookType], hook) + logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType])) +} + +func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) { + for _, hookType := range hookTypes { + r.Register(hookType, hook) + } +} + +func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error { + hooks, exists := r.hooks[hookType] + if !exists || len(hooks) == 0 { + return nil + } + + logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType) + + for i, hook := range hooks { + if err := hook(ctx); err != nil { + logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err) + return fmt.Errorf("hook execution failed: %w", err) + } + + if ctx.Abort { + logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage) + return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage) + } + } + + return nil +} + +func (r *HookRegistry) Clear(hookType HookType) { + delete(r.hooks, hookType) +} + +func (r *HookRegistry) ClearAll() { + r.hooks = make(map[HookType][]HookFunc) +} + +func (r *HookRegistry) HasHooks(hookType HookType) bool { + hooks, exists := r.hooks[hookType] + return exists && len(hooks) > 0 +} diff --git a/pkg/resolvemcp/resolvemcp.go b/pkg/resolvemcp/resolvemcp.go new file mode 100644 index 0000000..f85531e --- /dev/null +++ b/pkg/resolvemcp/resolvemcp.go @@ -0,0 +1,83 @@ +// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools +// and resources over HTTP/SSE transport. +// +// It mirrors the resolvespec package patterns: +// - Same model registration API +// - Same filter, sort, cursor pagination, preload options +// - Same lifecycle hook system +// +// Usage: +// +// handler := resolvemcp.NewHandlerWithGORM(db) +// handler.RegisterModel("public", "users", &User{}) +// +// r := mux.NewRouter() +// resolvemcp.SetupMuxRoutes(r, handler, "http://localhost:8080") +package resolvemcp + +import ( + "net/http" + + "github.com/gorilla/mux" + "github.com/mark3labs/mcp-go/server" + "github.com/uptrace/bun" + "gorm.io/gorm" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database" + "github.com/bitechdev/ResolveSpec/pkg/modelregistry" +) + +// NewHandlerWithGORM creates a Handler backed by a GORM database connection. +func NewHandlerWithGORM(db *gorm.DB) *Handler { + return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry()) +} + +// NewHandlerWithBun creates a Handler backed by a Bun database connection. +func NewHandlerWithBun(db *bun.DB) *Handler { + return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry()) +} + +// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry. +func NewHandlerWithDB(db common.Database) *Handler { + return NewHandler(db, modelregistry.NewModelRegistry()) +} + +// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router. +// +// baseURL is the public-facing base URL of the server (e.g. "http://localhost:8080"). +// It is sent to MCP clients during the SSE handshake so they know where to POST messages. +// +// Two routes are registered: +// - GET /mcp/sse — SSE connection endpoint (client subscribes here) +// - POST /mcp/message — JSON-RPC message endpoint (client sends requests here) +// +// To protect these routes with authentication, wrap the mux router or apply middleware +// before calling SetupMuxRoutes. +func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, baseURL string) { + sseServer := server.NewSSEServer( + handler.mcpServer, + server.WithBaseURL(baseURL), + server.WithBasePath("/mcp"), + ) + + muxRouter.Handle("/mcp/sse", sseServer.SSEHandler()).Methods("GET", "OPTIONS") + muxRouter.Handle("/mcp/message", sseServer.MessageHandler()).Methods("POST", "OPTIONS") + + // Convenience: also expose the full SSE server at /mcp for clients that + // use ServeHTTP directly (e.g. net/http default mux). + muxRouter.PathPrefix("/mcp").Handler(http.StripPrefix("/mcp", sseServer)) +} + +// NewSSEServer creates an *server.SSEServer that can be mounted manually, +// useful when integrating with non-Mux routers or adding extra middleware. +// +// sseServer := resolvemcp.NewSSEServer(handler, "http://localhost:8080", "/mcp") +// http.Handle("/mcp/", http.StripPrefix("/mcp", sseServer)) +func NewSSEServer(handler *Handler, baseURL, basePath string) *server.SSEServer { + return server.NewSSEServer( + handler.mcpServer, + server.WithBaseURL(baseURL), + server.WithBasePath(basePath), + ) +} diff --git a/pkg/resolvemcp/tools.go b/pkg/resolvemcp/tools.go new file mode 100644 index 0000000..7a09181 --- /dev/null +++ b/pkg/resolvemcp/tools.go @@ -0,0 +1,415 @@ +package resolvemcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +// toolName builds the MCP tool name for a given operation and model. +func toolName(operation, schema, entity string) string { + if schema == "" { + return fmt.Sprintf("%s_%s", operation, entity) + } + return fmt.Sprintf("%s_%s_%s", operation, schema, entity) +} + +// registerModelTools registers the four CRUD tools and resource for a model. +func registerModelTools(h *Handler, schema, entity string, model interface{}) { + registerReadTool(h, schema, entity) + registerCreateTool(h, schema, entity) + registerUpdateTool(h, schema, entity) + registerDeleteTool(h, schema, entity) + registerModelResource(h, schema, entity) + + logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity) +} + +// -------------------------------------------------------------------------- +// Read tool +// -------------------------------------------------------------------------- + +func registerReadTool(h *Handler, schema, entity string) { + name := toolName("read", schema, entity) + description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity)) + + tool := mcp.NewTool(name, + mcp.WithDescription(description), + mcp.WithString("id", + mcp.Description("Primary key of a single record to fetch (optional)"), + ), + mcp.WithNumber("limit", + mcp.Description("Maximum number of records to return"), + ), + mcp.WithNumber("offset", + mcp.Description("Number of records to skip"), + ), + mcp.WithString("cursor_forward", + mcp.Description("Cursor value for the next page (primary key of last record on current page)"), + ), + mcp.WithString("cursor_backward", + mcp.Description("Cursor value for the previous page"), + ), + mcp.WithArray("columns", + mcp.Description("List of column names to include in the result"), + ), + mcp.WithArray("omit_columns", + mcp.Description("List of column names to exclude from the result"), + ), + mcp.WithArray("filters", + mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`), + ), + mcp.WithArray("sort", + mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`), + ), + mcp.WithArray("preloads", + mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`), + ), + ) + + h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := req.GetArguments() + id, _ := args["id"].(string) + options := parseRequestOptions(args) + + data, metadata, err := h.executeRead(ctx, schema, entity, id, options) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + return marshalResult(map[string]interface{}{ + "success": true, + "data": data, + "metadata": metadata, + }) + }) +} + +// -------------------------------------------------------------------------- +// Create tool +// -------------------------------------------------------------------------- + +func registerCreateTool(h *Handler, schema, entity string) { + name := toolName("create", schema, entity) + description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity)) + + tool := mcp.NewTool(name, + mcp.WithDescription(description), + mcp.WithObject("data", + mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"), + mcp.Required(), + ), + ) + + h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := req.GetArguments() + data, ok := args["data"] + if !ok { + return mcp.NewToolResultError("missing required argument: data"), nil + } + + result, err := h.executeCreate(ctx, schema, entity, data) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + return marshalResult(map[string]interface{}{ + "success": true, + "data": result, + }) + }) +} + +// -------------------------------------------------------------------------- +// Update tool +// -------------------------------------------------------------------------- + +func registerUpdateTool(h *Handler, schema, entity string) { + name := toolName("update", schema, entity) + description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity)) + + tool := mcp.NewTool(name, + mcp.WithDescription(description), + mcp.WithString("id", + mcp.Description("Primary key of the record to update"), + ), + mcp.WithObject("data", + mcp.Description("Fields to update (non-null fields will be merged into the existing record)"), + mcp.Required(), + ), + ) + + h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := req.GetArguments() + id, _ := args["id"].(string) + + data, ok := args["data"] + if !ok { + return mcp.NewToolResultError("missing required argument: data"), nil + } + dataMap, ok := data.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("data must be an object"), nil + } + + result, err := h.executeUpdate(ctx, schema, entity, id, dataMap) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + return marshalResult(map[string]interface{}{ + "success": true, + "data": result, + }) + }) +} + +// -------------------------------------------------------------------------- +// Delete tool +// -------------------------------------------------------------------------- + +func registerDeleteTool(h *Handler, schema, entity string) { + name := toolName("delete", schema, entity) + description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity)) + + tool := mcp.NewTool(name, + mcp.WithDescription(description), + mcp.WithString("id", + mcp.Description("Primary key of the record to delete"), + mcp.Required(), + ), + ) + + h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := req.GetArguments() + id, _ := args["id"].(string) + + result, err := h.executeDelete(ctx, schema, entity, id) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + return marshalResult(map[string]interface{}{ + "success": true, + "data": result, + }) + }) +} + +// -------------------------------------------------------------------------- +// Resource registration +// -------------------------------------------------------------------------- + +func registerModelResource(h *Handler, schema, entity string) { + resourceURI := buildModelName(schema, entity) + displayName := entity + if schema != "" { + displayName = schema + "." + entity + } + + resource := mcp.NewResource( + resourceURI, + displayName, + mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)), + mcp.WithMIMEType("application/json"), + ) + + h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + limit := 100 + options := common.RequestOptions{Limit: &limit} + + data, metadata, err := h.executeRead(ctx, schema, entity, "", options) + if err != nil { + return nil, err + } + + payload := map[string]interface{}{ + "data": data, + "metadata": metadata, + } + jsonBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling resource: %w", err) + } + + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "application/json", + Text: string(jsonBytes), + }, + }, nil + }) +} + +// -------------------------------------------------------------------------- +// Argument parsing helpers +// -------------------------------------------------------------------------- + +// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions. +func parseRequestOptions(args map[string]interface{}) common.RequestOptions { + options := common.RequestOptions{} + + // limit + if v, ok := args["limit"]; ok { + switch n := v.(type) { + case float64: + limit := int(n) + options.Limit = &limit + case int: + options.Limit = &n + } + } + + // offset + if v, ok := args["offset"]; ok { + switch n := v.(type) { + case float64: + offset := int(n) + options.Offset = &offset + case int: + options.Offset = &n + } + } + + // cursor_forward / cursor_backward + if v, ok := args["cursor_forward"].(string); ok { + options.CursorForward = v + } + if v, ok := args["cursor_backward"].(string); ok { + options.CursorBackward = v + } + + // columns + options.Columns = parseStringArray(args["columns"]) + + // omit_columns + options.OmitColumns = parseStringArray(args["omit_columns"]) + + // filters — marshal each item and unmarshal into FilterOption + options.Filters = parseFilters(args["filters"]) + + // sort + options.Sort = parseSortOptions(args["sort"]) + + // preloads + options.Preload = parsePreloadOptions(args["preloads"]) + + return options +} + +func parseStringArray(raw interface{}) []string { + if raw == nil { + return nil + } + items, ok := raw.([]interface{}) + if !ok { + return nil + } + result := make([]string, 0, len(items)) + for _, item := range items { + if s, ok := item.(string); ok { + result = append(result, s) + } + } + return result +} + +func parseFilters(raw interface{}) []common.FilterOption { + if raw == nil { + return nil + } + items, ok := raw.([]interface{}) + if !ok { + return nil + } + result := make([]common.FilterOption, 0, len(items)) + for _, item := range items { + b, err := json.Marshal(item) + if err != nil { + continue + } + var f common.FilterOption + if err := json.Unmarshal(b, &f); err != nil { + continue + } + if f.Column == "" || f.Operator == "" { + continue + } + // Normalise logic operator + if strings.EqualFold(f.LogicOperator, "or") { + f.LogicOperator = "OR" + } else { + f.LogicOperator = "AND" + } + result = append(result, f) + } + return result +} + +func parseSortOptions(raw interface{}) []common.SortOption { + if raw == nil { + return nil + } + items, ok := raw.([]interface{}) + if !ok { + return nil + } + result := make([]common.SortOption, 0, len(items)) + for _, item := range items { + b, err := json.Marshal(item) + if err != nil { + continue + } + var s common.SortOption + if err := json.Unmarshal(b, &s); err != nil { + continue + } + if s.Column == "" { + continue + } + result = append(result, s) + } + return result +} + +func parsePreloadOptions(raw interface{}) []common.PreloadOption { + if raw == nil { + return nil + } + items, ok := raw.([]interface{}) + if !ok { + return nil + } + result := make([]common.PreloadOption, 0, len(items)) + for _, item := range items { + b, err := json.Marshal(item) + if err != nil { + continue + } + var p common.PreloadOption + if err := json.Unmarshal(b, &p); err != nil { + continue + } + if p.Relation == "" { + continue + } + result = append(result, p) + } + return result +} + +// marshalResult marshals a value to JSON and returns it as an MCP text result. +func marshalResult(v interface{}) (*mcp.CallToolResult, error) { + b, err := json.Marshal(v) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil + } + return mcp.NewToolResultText(string(b)), nil +} diff --git a/pkg/restheadspec/cursor.go b/pkg/restheadspec/cursor.go index b060a5b..ab776ed 100644 --- a/pkg/restheadspec/cursor.go +++ b/pkg/restheadspec/cursor.go @@ -93,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter( } // Handle joins - if isJoin && expandJoins != nil { - if joinClause, ok := expandJoins[prefix]; ok { - jSQL, cRef := rewriteJoin(joinClause, tableName, prefix) - joinSQL = jSQL - cursorCol = cRef + "." + field - targetCol = prefix + "." + field + if isJoin { + if expandJoins != nil { + if joinClause, ok := expandJoins[prefix]; ok { + jSQL, cRef := rewriteJoin(joinClause, tableName, prefix) + joinSQL = jSQL + cursorCol = cRef + "." + field + targetCol = prefix + "." + field + } + } + if cursorCol == "" { + logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix) + continue } } diff --git a/pkg/restheadspec/cursor_test.go b/pkg/restheadspec/cursor_test.go index a9de870..22b57b3 100644 --- a/pkg/restheadspec/cursor_test.go +++ b/pkg/restheadspec/cursor_test.go @@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) { } } +func TestGetCursorFilter_LateralJoin(t *testing.T) { + lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true" + + opts := &ExtendedRequestOptions{ + RequestOptions: common.RequestOptions{ + Sort: []common.SortOption{ + {Column: "fn.sortorder", Direction: "ASC"}, + }, + }, + } + opts.CursorForward = "8975" + + tableName := "core.account" + pkName := "rid_account" + // modelColumns does not contain "sortorder" - it's a lateral join computed column + modelColumns := []string{"rid_account", "description", "pastelno"} + expandJoins := map[string]string{"fn": lateralJoin} + + filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins) + if err != nil { + t.Fatalf("GetCursorFilter failed: %v", err) + } + + t.Logf("Generated lateral cursor filter: %s", filter) + + // Should contain the rewritten lateral join inside the EXISTS subquery + if !strings.Contains(filter, "cursor_select_fn") { + t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter) + } + + // Should compare fn.sortorder values + if !strings.Contains(filter, "sortorder") { + t.Errorf("Filter should reference sortorder column, got: %s", filter) + } + + // Should NOT contain empty comparison like "< " + if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") { + t.Errorf("Filter should not contain empty comparison operators, got: %s", filter) + } +} + func TestBuildPriorityChain(t *testing.T) { clauses := []string{ "cursor_select.priority > posts.priority", diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index fc32b22..b971480 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -723,13 +723,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Extract model columns for validation using the generic database function modelColumns := reflection.GetModelColumns(model) - // Build expand joins map (if needed in future) - var expandJoins map[string]string - if len(options.Expand) > 0 { - expandJoins = make(map[string]string) - // TODO: Build actual JOIN SQL for each expand relation - // For now, pass empty map as joins are handled via Preload + // Build expand joins map: custom SQL joins are available in cursor subquery + expandJoins := make(map[string]string) + for _, joinClause := range options.CustomSQLJoin { + alias := extractJoinAlias(joinClause) + if alias != "" { + expandJoins[alias] = joinClause + } } + // TODO: also add Expand relation JOINs when those are built as SQL rather than Preload // Default sort to primary key when none provided if len(options.Sort) == 0 { diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index cd84539..4af6e48 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -552,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri // - "LEFT JOIN departments d ON ..." -> "d" // - "INNER JOIN users AS u ON ..." -> "u" // - "JOIN roles r ON ..." -> "r" +// - "INNER JOIN LATERAL (...) fn ON true" -> "fn" func extractJoinAlias(joinClause string) string { - // Pattern: JOIN table_name [AS] alias ON ... - // We need to extract the alias (word before ON) - upperJoin := strings.ToUpper(joinClause) // Find the "JOIN" keyword position @@ -564,7 +562,20 @@ func extractJoinAlias(joinClause string) string { return "" } - // Find the "ON" keyword position + // Lateral joins: alias is the word after the closing ) and before ON + if strings.Contains(upperJoin, "LATERAL") { + lastClose := strings.LastIndex(joinClause, ")") + if lastClose != -1 { + words := strings.Fields(joinClause[lastClose+1:]) + // words should be like ["fn", "on", "true"] or ["on", "true"] + if len(words) >= 1 && !strings.EqualFold(words[0], "on") { + return words[0] + } + } + return "" + } + + // Regular joins: find the "ON" keyword position (first occurrence) onIdx := strings.Index(upperJoin, " ON ") if onIdx == -1 { return "" diff --git a/pkg/restheadspec/headers_test.go b/pkg/restheadspec/headers_test.go index d83d09f..34ea03f 100644 --- a/pkg/restheadspec/headers_test.go +++ b/pkg/restheadspec/headers_test.go @@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) { joinClause: "LEFT JOIN departments", expected: "", }, + { + name: "LATERAL join with alias", + joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true", + expected: "fn", + }, + { + name: "LATERAL join with multiline subquery containing inner ON", + joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true", + expected: "fn", + }, } for _, tt := range tests {