diff --git a/README.md b/README.md index dbb8152..e1a3403 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation 3. **FuncSpec** - Header-based API to map and call API's to sql functions 4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations 5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications +6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering. @@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation * [Quick Start](#quick-start) * [ResolveSpec (Body-Based API)](#resolvespec---body-based-api) * [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api) + * [ResolveMCP (MCP Server)](#resolvemcp---mcp-server) * [Architecture](#architecture) * [API Structure](#api-structure) * [RestHeadSpec Overview](#restheadspec-header-based-api) @@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation * **🆕 Backward Compatible**: Existing code works without changes * **🆕 Better Testing**: Mockable interfaces for easy unit testing +### ResolveMCP (v3.2+) + +* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources +* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing +* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model +* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters +* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client +* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects + ### RestHeadSpec (v2.1+) * **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body @@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil) For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). +### ResolveMCP (MCP Server) + +ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly: + +```go +import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp" + +// Create handler +handler := resolvemcp.NewHandlerWithGORM(db) + +// Register models — must be done BEFORE Build() +handler.RegisterModel("public", "users", &User{}) +handler.RegisterModel("public", "posts", &Post{}) + +// Finalize: registers MCP tools and resources +handler.Build() + +// Mount SSE transport on your existing router +router := mux.NewRouter() +resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080") + +// MCP clients connect to: +// SSE stream: GET http://localhost:8080/mcp/sse +// Messages: POST http://localhost:8080/mcp/message +// +// Auto-registered tools per model: +// read_public_users — filter, sort, paginate, preload +// create_public_users — insert a new record +// update_public_users — update a record by ID +// delete_public_users — delete a record by ID +``` + +For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source. + ## Architecture ### Two Complementary APIs @@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers. For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). +#### ResolveMCP - MCP Server + +Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE. + +**Key Features**: +- Four tools per model: `read_`, `create_`, `update_`, `delete_` +- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations +- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads +- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client +- Same Before/After lifecycle hooks as ResolveSpec + +For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/). + #### FuncSpec - Function-Based SQL API Execute SQL functions and queries through a simple HTTP API with header-based parameters. @@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file ## What's New -### v3.1 (Latest - February 2026) +### v3.2 (Latest - March 2026) + +**ResolveMCP - Model Context Protocol Server (🆕)**: + +* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport +* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing +* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters +* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client +* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects +* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients + +### v3.1 (February 2026) **SQLite Schema Translation (🆕)**: diff --git a/pkg/resolvemcp/cursor.go b/pkg/resolvemcp/cursor.go index 89668f1..25d038f 100644 --- a/pkg/resolvemcp/cursor.go +++ b/pkg/resolvemcp/cursor.go @@ -18,11 +18,13 @@ const ( ) // getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination. +// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support. func getCursorFilter( tableName string, pkName string, modelColumns []string, options common.RequestOptions, + expandJoins map[string]string, ) (string, error) { fullTableName := tableName if strings.Contains(tableName, ".") { @@ -40,6 +42,7 @@ func getCursorFilter( } var whereClauses []string + joinSQL := "" reverse := direction < 0 for _, s := range sortItems { @@ -57,12 +60,27 @@ func getCursorFilter( desc = !desc } - cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns) + cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns) if err != nil { logger.Warn("Skipping invalid sort column %q: %v", col, err) continue } + if isJoin { + if expandJoins != nil { + if joinClause, ok := expandJoins[prefix]; ok { + jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix) + joinSQL = jSQL + cursorCol = cRef + "." + field + targetCol = prefix + "." + field + } + } + if cursorCol == "" { + logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix) + continue + } + } + op := "<" if desc { op = ">" @@ -79,10 +97,12 @@ func getCursorFilter( query := fmt.Sprintf(`EXISTS ( SELECT 1 FROM %s cursor_select + %s WHERE cursor_select.%s = %s AND (%s) )`, fullTableName, + joinSQL, pkName, cursorID, orSQL, @@ -101,26 +121,34 @@ func getActiveCursor(options common.RequestOptions) (id string, direction cursor return "", 0 } -func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) { +func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) { if strings.Contains(field, "->") { - return "cursor_select." + field, tableName + "." + field, nil + return "cursor_select." + field, tableName + "." + field, false, nil } if modelColumns != nil { for _, col := range modelColumns { if strings.EqualFold(col, field) { - return "cursor_select." + field, tableName + "." + field, nil + return "cursor_select." + field, tableName + "." + field, false, nil } } } else { - return "cursor_select." + field, tableName + "." + field, nil + return "cursor_select." + field, tableName + "." + field, false, nil } if prefix != "" && prefix != tableName { - return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field) + return "", "", true, nil } - return "", "", fmt.Errorf("invalid column: %s", field) + return "", "", false, fmt.Errorf("invalid column: %s", field) +} + +func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) { + joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.") + cursorAlias = "cursor_select_" + alias + joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ") + joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".") + return joinSQL, cursorAlias } func buildCursorPriorityChain(clauses []string) string { diff --git a/pkg/resolvemcp/handler.go b/pkg/resolvemcp/handler.go index 589e5f7..42bc812 100644 --- a/pkg/resolvemcp/handler.go +++ b/pkg/resolvemcp/handler.go @@ -191,7 +191,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} } - cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options) + // expandJoins is empty for resolvemcp — no custom SQL join support yet + cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { return nil, nil, fmt.Errorf("cursor error: %w", err) } diff --git a/pkg/resolvemcp/tools.go b/pkg/resolvemcp/tools.go index 7a09181..3c96a0e 100644 --- a/pkg/resolvemcp/tools.go +++ b/pkg/resolvemcp/tools.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "fmt" + "reflect" "strings" "github.com/mark3labs/mcp-go/mcp" "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // toolName builds the MCP tool name for a given operation and model. @@ -22,54 +24,268 @@ func toolName(operation, schema, entity string) string { // registerModelTools registers the four CRUD tools and resource for a model. func registerModelTools(h *Handler, schema, entity string, model interface{}) { - registerReadTool(h, schema, entity) - registerCreateTool(h, schema, entity) - registerUpdateTool(h, schema, entity) - registerDeleteTool(h, schema, entity) - registerModelResource(h, schema, entity) + info := buildModelInfo(schema, entity, model) + registerReadTool(h, schema, entity, info) + registerCreateTool(h, schema, entity, info) + registerUpdateTool(h, schema, entity, info) + registerDeleteTool(h, schema, entity, info) + registerModelResource(h, schema, entity, info) - logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity) + logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName) +} + +// -------------------------------------------------------------------------- +// Model introspection +// -------------------------------------------------------------------------- + +// modelInfo holds pre-computed metadata for a model used in tool descriptions. +type modelInfo struct { + fullName string // e.g. "public.users" + pkName string // e.g. "id" + columns []columnInfo + relationNames []string + schemaDoc string // formatted multi-line schema listing +} + +type columnInfo struct { + jsonName string + sqlName string + goType string + sqlType string + isPrimary bool + isUnique bool + isFK bool + nullable bool +} + +// buildModelInfo extracts column metadata and pre-builds the schema documentation string. +func buildModelInfo(schema, entity string, model interface{}) modelInfo { + info := modelInfo{ + fullName: buildModelName(schema, entity), + pkName: reflection.GetPrimaryKeyName(model), + } + + // Unwrap to base struct type + modelType := reflect.TypeOf(model) + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) { + modelType = modelType.Elem() + } + if modelType == nil || modelType.Kind() != reflect.Struct { + return info + } + + details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem()) + + for _, d := range details { + // Derive the JSON name from the struct field + jsonName := fieldJSONName(modelType, d.Name) + if jsonName == "" || jsonName == "-" { + continue + } + + // Skip relation fields (slice or user-defined struct that isn't time.Time). + fieldType, found := modelType.FieldByName(d.Name) + if found { + ft := fieldType.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != "" + if ft.Kind() == reflect.Slice || isUserStruct { + info.relationNames = append(info.relationNames, jsonName) + continue + } + } + + sqlName := d.SQLName + if sqlName == "" { + sqlName = jsonName + } + + // Derive Go type name, unwrapping pointer if needed. + goType := d.DataType + if goType == "" && found { + ft := fieldType.Type + for ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + goType = ft.Name() + } + + // isPrimary: use both the GORM-tag detection and a name comparison against + // the known primary key (handles camelCase "primaryKey" tags correctly). + isPrimary := d.SQLKey == "primary_key" || + (info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName)) + + ci := columnInfo{ + jsonName: jsonName, + sqlName: sqlName, + goType: goType, + sqlType: d.SQLDataType, + isPrimary: isPrimary, + isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex", + isFK: d.SQLKey == "foreign_key", + nullable: d.Nullable, + } + info.columns = append(info.columns, ci) + } + + info.schemaDoc = buildSchemaDoc(info) + return info +} + +// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name. +func fieldJSONName(modelType reflect.Type, fieldName string) string { + field, ok := modelType.FieldByName(fieldName) + if !ok { + return fieldName + } + tag := field.Tag.Get("json") + if tag == "" { + return fieldName + } + parts := strings.SplitN(tag, ",", 2) + if parts[0] == "" { + return fieldName + } + return parts[0] +} + +// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions. +func buildSchemaDoc(info modelInfo) string { + if len(info.columns) == 0 { + return "" + } + + var sb strings.Builder + sb.WriteString("Columns:\n") + for _, c := range info.columns { + line := fmt.Sprintf(" • %s", c.jsonName) + + typeDesc := c.goType + if c.sqlType != "" { + typeDesc = c.sqlType + } + if typeDesc != "" { + line += fmt.Sprintf(" (%s)", typeDesc) + } + + var flags []string + if c.isPrimary { + flags = append(flags, "primary key") + } + if c.isUnique { + flags = append(flags, "unique") + } + if c.isFK { + flags = append(flags, "foreign key") + } + if !c.nullable && !c.isPrimary { + flags = append(flags, "not null") + } else if c.nullable { + flags = append(flags, "nullable") + } + if len(flags) > 0 { + line += " — " + strings.Join(flags, ", ") + } + + sb.WriteString(line + "\n") + } + + if len(info.relationNames) > 0 { + sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n") + } + + return sb.String() +} + +// columnNameList returns a comma-separated list of JSON column names (for descriptions). +func columnNameList(cols []columnInfo) string { + names := make([]string, len(cols)) + for i, c := range cols { + names[i] = c.jsonName + } + return strings.Join(names, ", ") +} + +// writableColumnNames returns JSON names for all non-primary-key columns. +func writableColumnNames(cols []columnInfo) []string { + var names []string + for _, c := range cols { + if !c.isPrimary { + names = append(names, c.jsonName) + } + } + return names } // -------------------------------------------------------------------------- // Read tool // -------------------------------------------------------------------------- -func registerReadTool(h *Handler, schema, entity string) { +func registerReadTool(h *Handler, schema, entity string, info modelInfo) { name := toolName("read", schema, entity) - description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity)) + + var descParts []string + descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName)) + if info.pkName != "" { + descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName)) + } + if info.schemaDoc != "" { + descParts = append(descParts, info.schemaDoc) + } + descParts = append(descParts, + "Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.", + "Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.", + "Sorting: each sort object requires 'column' and 'direction' (asc or desc).", + ) + if len(info.relationNames) > 0 { + descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", "))) + } + + description := strings.Join(descParts, "\n\n") + + filterDesc := fmt.Sprintf(`Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`) + if len(info.columns) > 0 { + filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns)) + } + + sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]` + if len(info.columns) > 0 { + sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns)) + } tool := mcp.NewTool(name, mcp.WithDescription(description), mcp.WithString("id", - mcp.Description("Primary key of a single record to fetch (optional)"), + mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)), ), mcp.WithNumber("limit", - mcp.Description("Maximum number of records to return"), + mcp.Description("Maximum number of records to return per page. Recommended: 10–100."), ), mcp.WithNumber("offset", - mcp.Description("Number of records to skip"), + mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."), ), mcp.WithString("cursor_forward", - mcp.Description("Cursor value for the next page (primary key of last record on current page)"), + mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)), ), mcp.WithString("cursor_backward", - mcp.Description("Cursor value for the previous page"), + mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)), ), mcp.WithArray("columns", - mcp.Description("List of column names to include in the result"), + mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))), ), mcp.WithArray("omit_columns", - mcp.Description("List of column names to exclude from the result"), + mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))), ), mcp.WithArray("filters", - mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`), + mcp.Description(filterDesc), ), mcp.WithArray("sort", - mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`), + mcp.Description(sortDesc), ), mcp.WithArray("preloads", - mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`), + mcp.Description(buildPreloadDesc(info)), ), ) @@ -91,18 +307,52 @@ func registerReadTool(h *Handler, schema, entity string) { }) } +func buildPreloadDesc(info modelInfo) string { + if len(info.relationNames) == 0 { + return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.` + } + return fmt.Sprintf( + `Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`, + strings.Join(info.relationNames, ", "), + ) +} + // -------------------------------------------------------------------------- // Create tool // -------------------------------------------------------------------------- -func registerCreateTool(h *Handler, schema, entity string) { +func registerCreateTool(h *Handler, schema, entity string, info modelInfo) { name := toolName("create", schema, entity) - description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity)) + + writable := writableColumnNames(info.columns) + + var descParts []string + descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName)) + if len(writable) > 0 { + descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", "))) + } + if info.pkName != "" { + descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName)) + } + descParts = append(descParts, + "Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).", + ) + if info.schemaDoc != "" { + descParts = append(descParts, info.schemaDoc) + } + + description := strings.Join(descParts, "\n\n") + + dataDesc := "Record fields to create." + if len(writable) > 0 { + dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", ")) + } + dataDesc += " Pass a single object or an array of objects." tool := mcp.NewTool(name, mcp.WithDescription(description), mcp.WithObject("data", - mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"), + mcp.Description(dataDesc), mcp.Required(), ), ) @@ -130,17 +380,42 @@ func registerCreateTool(h *Handler, schema, entity string) { // Update tool // -------------------------------------------------------------------------- -func registerUpdateTool(h *Handler, schema, entity string) { +func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) { name := toolName("update", schema, entity) - description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity)) + + writable := writableColumnNames(info.columns) + + var descParts []string + descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName)) + if info.pkName != "" { + descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName)) + } + if len(writable) > 0 { + descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", "))) + } + descParts = append(descParts, + "Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.", + ) + if info.schemaDoc != "" { + descParts = append(descParts, info.schemaDoc) + } + + description := strings.Join(descParts, "\n\n") + + idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName) + + dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)." + if len(writable) > 0 { + dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", ")) + } tool := mcp.NewTool(name, mcp.WithDescription(description), mcp.WithString("id", - mcp.Description("Primary key of the record to update"), + mcp.Description(idDesc), ), mcp.WithObject("data", - mcp.Description("Fields to update (non-null fields will be merged into the existing record)"), + mcp.Description(dataDesc), mcp.Required(), ), ) @@ -174,14 +449,23 @@ func registerUpdateTool(h *Handler, schema, entity string) { // Delete tool // -------------------------------------------------------------------------- -func registerDeleteTool(h *Handler, schema, entity string) { +func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) { name := toolName("delete", schema, entity) - description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity)) + + descParts := []string{ + fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName), + } + if info.pkName != "" { + descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName)) + } + descParts = append(descParts, "Returns the deleted record. This operation is irreversible.") + + description := strings.Join(descParts, " ") tool := mcp.NewTool(name, mcp.WithDescription(description), mcp.WithString("id", - mcp.Description("Primary key of the record to delete"), + mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)), mcp.Required(), ), ) @@ -206,17 +490,23 @@ func registerDeleteTool(h *Handler, schema, entity string) { // Resource registration // -------------------------------------------------------------------------- -func registerModelResource(h *Handler, schema, entity string) { - resourceURI := buildModelName(schema, entity) - displayName := entity - if schema != "" { - displayName = schema + "." + entity +func registerModelResource(h *Handler, schema, entity string, info modelInfo) { + resourceURI := info.fullName + + var resourceDesc strings.Builder + resourceDesc.WriteString(fmt.Sprintf("Database table: %s", info.fullName)) + if info.pkName != "" { + resourceDesc.WriteString(fmt.Sprintf(" (primary key: %s)", info.pkName)) + } + if info.schemaDoc != "" { + resourceDesc.WriteString("\n\n") + resourceDesc.WriteString(info.schemaDoc) } resource := mcp.NewResource( resourceURI, - displayName, - mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)), + entity, + mcp.WithResourceDescription(resourceDesc.String()), mcp.WithMIMEType("application/json"), ) @@ -256,7 +546,6 @@ func registerModelResource(h *Handler, schema, entity string) { func parseRequestOptions(args map[string]interface{}) common.RequestOptions { options := common.RequestOptions{} - // limit if v, ok := args["limit"]; ok { switch n := v.(type) { case float64: @@ -267,7 +556,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions { } } - // offset if v, ok := args["offset"]; ok { switch n := v.(type) { case float64: @@ -278,7 +566,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions { } } - // cursor_forward / cursor_backward if v, ok := args["cursor_forward"].(string); ok { options.CursorForward = v } @@ -286,19 +573,10 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions { options.CursorBackward = v } - // columns options.Columns = parseStringArray(args["columns"]) - - // omit_columns options.OmitColumns = parseStringArray(args["omit_columns"]) - - // filters — marshal each item and unmarshal into FilterOption options.Filters = parseFilters(args["filters"]) - - // sort options.Sort = parseSortOptions(args["sort"]) - - // preloads options.Preload = parsePreloadOptions(args["preloads"]) return options @@ -342,7 +620,6 @@ func parseFilters(raw interface{}) []common.FilterOption { if f.Column == "" || f.Operator == "" { continue } - // Normalise logic operator if strings.EqualFold(f.LogicOperator, "or") { f.LogicOperator = "OR" } else { diff --git a/pkg/resolvespec/cursor.go b/pkg/resolvespec/cursor.go index 58f7df9..879e4e1 100644 --- a/pkg/resolvespec/cursor.go +++ b/pkg/resolvespec/cursor.go @@ -24,6 +24,7 @@ const ( // - pkName: primary key column (e.g. "id") // - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip. // - options: the request options containing sort and cursor information +// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support // // Returns SQL snippet to embed in WHERE clause. func GetCursorFilter( @@ -31,6 +32,7 @@ func GetCursorFilter( pkName string, modelColumns []string, options common.RequestOptions, + expandJoins map[string]string, ) (string, error) { // Separate schema prefix from bare table name fullTableName := tableName @@ -58,6 +60,7 @@ func GetCursorFilter( // 3. Prepare // --------------------------------------------------------------------- // var whereClauses []string + joinSQL := "" reverse := direction < 0 // --------------------------------------------------------------------- // @@ -69,7 +72,7 @@ func GetCursorFilter( continue } - // Parse: "created_at", "user.name", etc. + // Parse: "created_at", "user.name", "fn.sortorder", etc. parts := strings.Split(col, ".") field := strings.TrimSpace(parts[len(parts)-1]) prefix := strings.Join(parts[:len(parts)-1], ".") @@ -82,7 +85,7 @@ func GetCursorFilter( } // Resolve column - cursorCol, targetCol, err := resolveColumn( + cursorCol, targetCol, isJoin, err := resolveColumn( field, prefix, tableName, modelColumns, ) if err != nil { @@ -90,6 +93,22 @@ func GetCursorFilter( continue } + // Handle joins + if isJoin { + if expandJoins != nil { + if joinClause, ok := expandJoins[prefix]; ok { + jSQL, cRef := rewriteJoin(joinClause, tableName, prefix) + joinSQL = jSQL + cursorCol = cRef + "." + field + targetCol = prefix + "." + field + } + } + if cursorCol == "" { + logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix) + continue + } + } + // Build inequality op := "<" if desc { @@ -113,10 +132,12 @@ func GetCursorFilter( query := fmt.Sprintf(`EXISTS ( SELECT 1 FROM %s cursor_select + %s WHERE cursor_select.%s = %s AND (%s) )`, fullTableName, + joinSQL, pkName, cursorID, orSQL, @@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor return "", 0 } -// Helper: resolve column (main table only for now) +// Helper: resolve column (main table or join) func resolveColumn( field, prefix, tableName string, modelColumns []string, -) (cursorCol, targetCol string, err error) { +) (cursorCol, targetCol string, isJoin bool, err error) { // JSON field if strings.Contains(field, "->") { - return "cursor_select." + field, tableName + "." + field, nil + return "cursor_select." + field, tableName + "." + field, false, nil } // Main table column if modelColumns != nil { for _, col := range modelColumns { if strings.EqualFold(col, field) { - return "cursor_select." + field, tableName + "." + field, nil + return "cursor_select." + field, tableName + "." + field, false, nil } } } else { // No validation → allow all main-table fields - return "cursor_select." + field, tableName + "." + field, nil + return "cursor_select." + field, tableName + "." + field, false, nil } - // Joined column (not supported in resolvespec yet) + // Joined column if prefix != "" && prefix != tableName { - return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field) + return "", "", true, nil } - return "", "", fmt.Errorf("invalid column: %s", field) + return "", "", false, fmt.Errorf("invalid column: %s", field) +} + +// Helper: rewrite JOIN clause for cursor subquery +func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) { + joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.") + cursorAlias = "cursor_select_" + alias + joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ") + joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".") + return joinSQL, cursorAlias } // ------------------------------------------------------------------------- // diff --git a/pkg/resolvespec/cursor_test.go b/pkg/resolvespec/cursor_test.go index 0b7b1fc..5ffaeef 100644 --- a/pkg/resolvespec/cursor_test.go +++ b/pkg/resolvespec/cursor_test.go @@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) { pkName := "id" modelColumns := []string{"id", "title", "created_at", "user_id"} - filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) + filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { t.Fatalf("GetCursorFilter failed: %v", err) } @@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) { pkName := "id" modelColumns := []string{"id", "title", "created_at", "user_id"} - filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) + filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { t.Fatalf("GetCursorFilter failed: %v", err) } @@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) { pkName := "id" modelColumns := []string{"id", "title", "created_at"} - _, err := GetCursorFilter(tableName, pkName, modelColumns, options) + _, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err == nil { t.Error("Expected error when no cursor is provided") } @@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) { pkName := "id" modelColumns := []string{"id", "title"} - _, err := GetCursorFilter(tableName, pkName, modelColumns, options) + _, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err == nil { t.Error("Expected error when no sort columns are defined") } @@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) { pkName := "id" modelColumns := []string{"id", "title", "priority", "created_at"} - filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) + filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { t.Fatalf("GetCursorFilter failed: %v", err) } @@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) { pkName := "id" modelColumns := []string{"id", "name", "email"} - filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) + filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { t.Fatalf("GetCursorFilter failed: %v", err) } @@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) { t.Logf("Generated cursor filter with schema: %s", filter) } +func TestGetCursorFilter_LateralJoin(t *testing.T) { + lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true" + + options := common.RequestOptions{ + Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}}, + CursorForward: "8975", + } + + tableName := "core.account" + pkName := "rid_account" + modelColumns := []string{"rid_account", "description", "pastelno"} + expandJoins := map[string]string{"fn": lateralJoin} + + filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins) + if err != nil { + t.Fatalf("GetCursorFilter failed: %v", err) + } + + t.Logf("Generated lateral cursor filter: %s", filter) + + if !strings.Contains(filter, "cursor_select_fn") { + t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter) + } + if !strings.Contains(filter, "sortorder") { + t.Errorf("Filter should reference sortorder column, got: %s", filter) + } + if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") { + t.Errorf("Filter should not contain empty comparison operators, got: %s", filter) + } +} + func TestGetActiveCursor(t *testing.T) { tests := []struct { name string @@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) { wantErr: false, }, { - name: "Joined column (not supported)", + name: "Joined column (isJoin=true, no error)", field: "name", prefix: "user", tableName: "posts", modelColumns: []string{"id", "title"}, - wantErr: true, + wantErr: false, + // cursorCol and targetCol are empty when isJoin=true; handled by caller }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns) + cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns) if tt.wantErr { if err == nil { @@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } + // For join columns, cursor/target are empty and isJoin=true + if isJoin { + if cursor != "" || target != "" { + t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target) + } + return + } + if cursor != tt.wantCursor { t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor) } @@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) { pkName := "id" modelColumns := []string{"id", "created_at"} - filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) + filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { t.Fatalf("GetCursorFilter failed: %v", err) } diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 1a7dd18..8fd163a 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} } - // Get cursor filter SQL - cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options) + // Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet) + cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil) if err != nil { logger.Error("Error building cursor filter: %v", err) h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)