feat(resolvemcp): add support for join-column sorting in cursor pagination

* Enhance getCursorFilter to accept join clauses for sorting
* Update resolveColumn to handle joined columns
* Modify tests to validate new join functionality
This commit is contained in:
Hein
2026-03-27 13:10:42 +02:00
parent 835bbb0727
commit 7f6410f665
7 changed files with 524 additions and 79 deletions

View File

@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
3. **FuncSpec** - Header-based API to map and call API's to sql functions 3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations 4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications 5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering. All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
* [Quick Start](#quick-start) * [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api) * [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api) * [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
* [Architecture](#architecture) * [Architecture](#architecture)
* [API Structure](#api-structure) * [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api) * [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
* **🆕 Backward Compatible**: Existing code works without changes * **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing * **🆕 Better Testing**: Mockable interfaces for easy unit testing
### ResolveMCP (v3.2+)
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
### RestHeadSpec (v2.1+) ### RestHeadSpec (v2.1+)
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body * **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
### ResolveMCP (MCP Server)
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
// Create handler
handler := resolvemcp.NewHandlerWithGORM(db)
// Register models — must be done BEFORE Build()
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "posts", &Post{})
// Finalize: registers MCP tools and resources
handler.Build()
// Mount SSE transport on your existing router
router := mux.NewRouter()
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
// MCP clients connect to:
// SSE stream: GET http://localhost:8080/mcp/sse
// Messages: POST http://localhost:8080/mcp/message
//
// Auto-registered tools per model:
// read_public_users — filter, sort, paginate, preload
// create_public_users — insert a new record
// update_public_users — update a record by ID
// delete_public_users — delete a record by ID
```
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
## Architecture ## Architecture
### Two Complementary APIs ### Two Complementary APIs
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
#### ResolveMCP - MCP Server
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
**Key Features**:
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
- Same Before/After lifecycle hooks as ResolveSpec
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
#### FuncSpec - Function-Based SQL API #### FuncSpec - Function-Based SQL API
Execute SQL functions and queries through a simple HTTP API with header-based parameters. Execute SQL functions and queries through a simple HTTP API with header-based parameters.
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New ## What's New
### v3.1 (Latest - February 2026) ### v3.2 (Latest - March 2026)
**ResolveMCP - Model Context Protocol Server (🆕)**:
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
### v3.1 (February 2026)
**SQLite Schema Translation (🆕)**: **SQLite Schema Translation (🆕)**:

View File

@@ -18,11 +18,13 @@ const (
) )
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination. // getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
func getCursorFilter( func getCursorFilter(
tableName string, tableName string,
pkName string, pkName string,
modelColumns []string, modelColumns []string,
options common.RequestOptions, options common.RequestOptions,
expandJoins map[string]string,
) (string, error) { ) (string, error) {
fullTableName := tableName fullTableName := tableName
if strings.Contains(tableName, ".") { if strings.Contains(tableName, ".") {
@@ -40,6 +42,7 @@ func getCursorFilter(
} }
var whereClauses []string var whereClauses []string
joinSQL := ""
reverse := direction < 0 reverse := direction < 0
for _, s := range sortItems { for _, s := range sortItems {
@@ -57,12 +60,27 @@ func getCursorFilter(
desc = !desc desc = !desc
} }
cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns) cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil { if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err) logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue continue
} }
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
op := "<" op := "<"
if desc { if desc {
op = ">" op = ">"
@@ -79,10 +97,12 @@ func getCursorFilter(
query := fmt.Sprintf(`EXISTS ( query := fmt.Sprintf(`EXISTS (
SELECT 1 SELECT 1
FROM %s cursor_select FROM %s cursor_select
%s
WHERE cursor_select.%s = %s WHERE cursor_select.%s = %s
AND (%s) AND (%s)
)`, )`,
fullTableName, fullTableName,
joinSQL,
pkName, pkName,
cursorID, cursorID,
orSQL, orSQL,
@@ -101,26 +121,34 @@ func getActiveCursor(options common.RequestOptions) (id string, direction cursor
return "", 0 return "", 0
} }
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) { func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
if strings.Contains(field, "->") { if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
if modelColumns != nil { if modelColumns != nil {
for _, col := range modelColumns { for _, col := range modelColumns {
if strings.EqualFold(col, field) { if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
} }
} else { } else {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
if prefix != "" && prefix != tableName { if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field) return "", "", true, nil
} }
return "", "", fmt.Errorf("invalid column: %s", field) return "", "", false, fmt.Errorf("invalid column: %s", field)
}
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
} }
func buildCursorPriorityChain(clauses []string) string { func buildCursorPriorityChain(clauses []string) string {

View File

@@ -191,7 +191,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
} }
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options) // expandJoins is empty for resolvemcp — no custom SQL join support yet
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err) return nil, nil, fmt.Errorf("cursor error: %w", err)
} }

View File

@@ -4,12 +4,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"strings" "strings"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// toolName builds the MCP tool name for a given operation and model. // toolName builds the MCP tool name for a given operation and model.
@@ -22,54 +24,268 @@ func toolName(operation, schema, entity string) string {
// registerModelTools registers the four CRUD tools and resource for a model. // registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) { func registerModelTools(h *Handler, schema, entity string, model interface{}) {
registerReadTool(h, schema, entity) info := buildModelInfo(schema, entity, model)
registerCreateTool(h, schema, entity) registerReadTool(h, schema, entity, info)
registerUpdateTool(h, schema, entity) registerCreateTool(h, schema, entity, info)
registerDeleteTool(h, schema, entity) registerUpdateTool(h, schema, entity, info)
registerModelResource(h, schema, entity) registerDeleteTool(h, schema, entity, info)
registerModelResource(h, schema, entity, info)
logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity) logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
}
// --------------------------------------------------------------------------
// Model introspection
// --------------------------------------------------------------------------
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
type modelInfo struct {
fullName string // e.g. "public.users"
pkName string // e.g. "id"
columns []columnInfo
relationNames []string
schemaDoc string // formatted multi-line schema listing
}
type columnInfo struct {
jsonName string
sqlName string
goType string
sqlType string
isPrimary bool
isUnique bool
isFK bool
nullable bool
}
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
info := modelInfo{
fullName: buildModelName(schema, entity),
pkName: reflection.GetPrimaryKeyName(model),
}
// Unwrap to base struct type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return info
}
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
for _, d := range details {
// Derive the JSON name from the struct field
jsonName := fieldJSONName(modelType, d.Name)
if jsonName == "" || jsonName == "-" {
continue
}
// Skip relation fields (slice or user-defined struct that isn't time.Time).
fieldType, found := modelType.FieldByName(d.Name)
if found {
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
if ft.Kind() == reflect.Slice || isUserStruct {
info.relationNames = append(info.relationNames, jsonName)
continue
}
}
sqlName := d.SQLName
if sqlName == "" {
sqlName = jsonName
}
// Derive Go type name, unwrapping pointer if needed.
goType := d.DataType
if goType == "" && found {
ft := fieldType.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
goType = ft.Name()
}
// isPrimary: use both the GORM-tag detection and a name comparison against
// the known primary key (handles camelCase "primaryKey" tags correctly).
isPrimary := d.SQLKey == "primary_key" ||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
ci := columnInfo{
jsonName: jsonName,
sqlName: sqlName,
goType: goType,
sqlType: d.SQLDataType,
isPrimary: isPrimary,
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
isFK: d.SQLKey == "foreign_key",
nullable: d.Nullable,
}
info.columns = append(info.columns, ci)
}
info.schemaDoc = buildSchemaDoc(info)
return info
}
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
func fieldJSONName(modelType reflect.Type, fieldName string) string {
field, ok := modelType.FieldByName(fieldName)
if !ok {
return fieldName
}
tag := field.Tag.Get("json")
if tag == "" {
return fieldName
}
parts := strings.SplitN(tag, ",", 2)
if parts[0] == "" {
return fieldName
}
return parts[0]
}
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
func buildSchemaDoc(info modelInfo) string {
if len(info.columns) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("Columns:\n")
for _, c := range info.columns {
line := fmt.Sprintf(" • %s", c.jsonName)
typeDesc := c.goType
if c.sqlType != "" {
typeDesc = c.sqlType
}
if typeDesc != "" {
line += fmt.Sprintf(" (%s)", typeDesc)
}
var flags []string
if c.isPrimary {
flags = append(flags, "primary key")
}
if c.isUnique {
flags = append(flags, "unique")
}
if c.isFK {
flags = append(flags, "foreign key")
}
if !c.nullable && !c.isPrimary {
flags = append(flags, "not null")
} else if c.nullable {
flags = append(flags, "nullable")
}
if len(flags) > 0 {
line += " — " + strings.Join(flags, ", ")
}
sb.WriteString(line + "\n")
}
if len(info.relationNames) > 0 {
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
}
return sb.String()
}
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
func columnNameList(cols []columnInfo) string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.jsonName
}
return strings.Join(names, ", ")
}
// writableColumnNames returns JSON names for all non-primary-key columns.
func writableColumnNames(cols []columnInfo) []string {
var names []string
for _, c := range cols {
if !c.isPrimary {
names = append(names, c.jsonName)
}
}
return names
} }
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Read tool // Read tool
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string) { func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("read", schema, entity) name := toolName("read", schema, entity)
description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity))
var descParts []string
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
}
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
descParts = append(descParts,
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
)
if len(info.relationNames) > 0 {
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
}
description := strings.Join(descParts, "\n\n")
filterDesc := fmt.Sprintf(`Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`)
if len(info.columns) > 0 {
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
if len(info.columns) > 0 {
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
tool := mcp.NewTool(name, tool := mcp.NewTool(name,
mcp.WithDescription(description), mcp.WithDescription(description),
mcp.WithString("id", mcp.WithString("id",
mcp.Description("Primary key of a single record to fetch (optional)"), mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
), ),
mcp.WithNumber("limit", mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return"), mcp.Description("Maximum number of records to return per page. Recommended: 10100."),
), ),
mcp.WithNumber("offset", mcp.WithNumber("offset",
mcp.Description("Number of records to skip"), mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
), ),
mcp.WithString("cursor_forward", mcp.WithString("cursor_forward",
mcp.Description("Cursor value for the next page (primary key of last record on current page)"), mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
), ),
mcp.WithString("cursor_backward", mcp.WithString("cursor_backward",
mcp.Description("Cursor value for the previous page"), mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
), ),
mcp.WithArray("columns", mcp.WithArray("columns",
mcp.Description("List of column names to include in the result"), mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
), ),
mcp.WithArray("omit_columns", mcp.WithArray("omit_columns",
mcp.Description("List of column names to exclude from the result"), mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
), ),
mcp.WithArray("filters", mcp.WithArray("filters",
mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`), mcp.Description(filterDesc),
), ),
mcp.WithArray("sort", mcp.WithArray("sort",
mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`), mcp.Description(sortDesc),
), ),
mcp.WithArray("preloads", mcp.WithArray("preloads",
mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`), mcp.Description(buildPreloadDesc(info)),
), ),
) )
@@ -91,18 +307,52 @@ func registerReadTool(h *Handler, schema, entity string) {
}) })
} }
func buildPreloadDesc(info modelInfo) string {
if len(info.relationNames) == 0 {
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
}
return fmt.Sprintf(
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
strings.Join(info.relationNames, ", "),
)
}
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Create tool // Create tool
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string) { func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("create", schema, entity) name := toolName("create", schema, entity)
description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity))
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
}
descParts = append(descParts,
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
dataDesc := "Record fields to create."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
}
dataDesc += " Pass a single object or an array of objects."
tool := mcp.NewTool(name, tool := mcp.NewTool(name,
mcp.WithDescription(description), mcp.WithDescription(description),
mcp.WithObject("data", mcp.WithObject("data",
mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"), mcp.Description(dataDesc),
mcp.Required(), mcp.Required(),
), ),
) )
@@ -130,17 +380,42 @@ func registerCreateTool(h *Handler, schema, entity string) {
// Update tool // Update tool
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string) { func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("update", schema, entity) name := toolName("update", schema, entity)
description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity))
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
}
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
}
descParts = append(descParts,
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
}
tool := mcp.NewTool(name, tool := mcp.NewTool(name,
mcp.WithDescription(description), mcp.WithDescription(description),
mcp.WithString("id", mcp.WithString("id",
mcp.Description("Primary key of the record to update"), mcp.Description(idDesc),
), ),
mcp.WithObject("data", mcp.WithObject("data",
mcp.Description("Fields to update (non-null fields will be merged into the existing record)"), mcp.Description(dataDesc),
mcp.Required(), mcp.Required(),
), ),
) )
@@ -174,14 +449,23 @@ func registerUpdateTool(h *Handler, schema, entity string) {
// Delete tool // Delete tool
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string) { func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("delete", schema, entity) name := toolName("delete", schema, entity)
description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity))
descParts := []string{
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
}
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
description := strings.Join(descParts, " ")
tool := mcp.NewTool(name, tool := mcp.NewTool(name,
mcp.WithDescription(description), mcp.WithDescription(description),
mcp.WithString("id", mcp.WithString("id",
mcp.Description("Primary key of the record to delete"), mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
mcp.Required(), mcp.Required(),
), ),
) )
@@ -206,17 +490,23 @@ func registerDeleteTool(h *Handler, schema, entity string) {
// Resource registration // Resource registration
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string) { func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
resourceURI := buildModelName(schema, entity) resourceURI := info.fullName
displayName := entity
if schema != "" { var resourceDesc strings.Builder
displayName = schema + "." + entity resourceDesc.WriteString(fmt.Sprintf("Database table: %s", info.fullName))
if info.pkName != "" {
resourceDesc.WriteString(fmt.Sprintf(" (primary key: %s)", info.pkName))
}
if info.schemaDoc != "" {
resourceDesc.WriteString("\n\n")
resourceDesc.WriteString(info.schemaDoc)
} }
resource := mcp.NewResource( resource := mcp.NewResource(
resourceURI, resourceURI,
displayName, entity,
mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)), mcp.WithResourceDescription(resourceDesc.String()),
mcp.WithMIMEType("application/json"), mcp.WithMIMEType("application/json"),
) )
@@ -256,7 +546,6 @@ func registerModelResource(h *Handler, schema, entity string) {
func parseRequestOptions(args map[string]interface{}) common.RequestOptions { func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{} options := common.RequestOptions{}
// limit
if v, ok := args["limit"]; ok { if v, ok := args["limit"]; ok {
switch n := v.(type) { switch n := v.(type) {
case float64: case float64:
@@ -267,7 +556,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
} }
} }
// offset
if v, ok := args["offset"]; ok { if v, ok := args["offset"]; ok {
switch n := v.(type) { switch n := v.(type) {
case float64: case float64:
@@ -278,7 +566,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
} }
} }
// cursor_forward / cursor_backward
if v, ok := args["cursor_forward"].(string); ok { if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v options.CursorForward = v
} }
@@ -286,19 +573,10 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options.CursorBackward = v options.CursorBackward = v
} }
// columns
options.Columns = parseStringArray(args["columns"]) options.Columns = parseStringArray(args["columns"])
// omit_columns
options.OmitColumns = parseStringArray(args["omit_columns"]) options.OmitColumns = parseStringArray(args["omit_columns"])
// filters — marshal each item and unmarshal into FilterOption
options.Filters = parseFilters(args["filters"]) options.Filters = parseFilters(args["filters"])
// sort
options.Sort = parseSortOptions(args["sort"]) options.Sort = parseSortOptions(args["sort"])
// preloads
options.Preload = parsePreloadOptions(args["preloads"]) options.Preload = parsePreloadOptions(args["preloads"])
return options return options
@@ -342,7 +620,6 @@ func parseFilters(raw interface{}) []common.FilterOption {
if f.Column == "" || f.Operator == "" { if f.Column == "" || f.Operator == "" {
continue continue
} }
// Normalise logic operator
if strings.EqualFold(f.LogicOperator, "or") { if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR" f.LogicOperator = "OR"
} else { } else {

View File

@@ -24,6 +24,7 @@ const (
// - pkName: primary key column (e.g. "id") // - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip. // - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - options: the request options containing sort and cursor information // - options: the request options containing sort and cursor information
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
// //
// Returns SQL snippet to embed in WHERE clause. // Returns SQL snippet to embed in WHERE clause.
func GetCursorFilter( func GetCursorFilter(
@@ -31,6 +32,7 @@ func GetCursorFilter(
pkName string, pkName string,
modelColumns []string, modelColumns []string,
options common.RequestOptions, options common.RequestOptions,
expandJoins map[string]string,
) (string, error) { ) (string, error) {
// Separate schema prefix from bare table name // Separate schema prefix from bare table name
fullTableName := tableName fullTableName := tableName
@@ -58,6 +60,7 @@ func GetCursorFilter(
// 3. Prepare // 3. Prepare
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
var whereClauses []string var whereClauses []string
joinSQL := ""
reverse := direction < 0 reverse := direction < 0
// --------------------------------------------------------------------- // // --------------------------------------------------------------------- //
@@ -69,7 +72,7 @@ func GetCursorFilter(
continue continue
} }
// Parse: "created_at", "user.name", etc. // Parse: "created_at", "user.name", "fn.sortorder", etc.
parts := strings.Split(col, ".") parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1]) field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".") prefix := strings.Join(parts[:len(parts)-1], ".")
@@ -82,7 +85,7 @@ func GetCursorFilter(
} }
// Resolve column // Resolve column
cursorCol, targetCol, err := resolveColumn( cursorCol, targetCol, isJoin, err := resolveColumn(
field, prefix, tableName, modelColumns, field, prefix, tableName, modelColumns,
) )
if err != nil { if err != nil {
@@ -90,6 +93,22 @@ func GetCursorFilter(
continue continue
} }
// Handle joins
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
// Build inequality // Build inequality
op := "<" op := "<"
if desc { if desc {
@@ -113,10 +132,12 @@ func GetCursorFilter(
query := fmt.Sprintf(`EXISTS ( query := fmt.Sprintf(`EXISTS (
SELECT 1 SELECT 1
FROM %s cursor_select FROM %s cursor_select
%s
WHERE cursor_select.%s = %s WHERE cursor_select.%s = %s
AND (%s) AND (%s)
)`, )`,
fullTableName, fullTableName,
joinSQL,
pkName, pkName,
cursorID, cursorID,
orSQL, orSQL,
@@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
return "", 0 return "", 0
} }
// Helper: resolve column (main table only for now) // Helper: resolve column (main table or join)
func resolveColumn( func resolveColumn(
field, prefix, tableName string, field, prefix, tableName string,
modelColumns []string, modelColumns []string,
) (cursorCol, targetCol string, err error) { ) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field // JSON field
if strings.Contains(field, "->") { if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
// Main table column // Main table column
if modelColumns != nil { if modelColumns != nil {
for _, col := range modelColumns { for _, col := range modelColumns {
if strings.EqualFold(col, field) { if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
} }
} else { } else {
// No validation → allow all main-table fields // No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, nil return "cursor_select." + field, tableName + "." + field, false, nil
} }
// Joined column (not supported in resolvespec yet) // Joined column
if prefix != "" && prefix != tableName { if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field) return "", "", true, nil
} }
return "", "", fmt.Errorf("invalid column: %s", field) return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
} }
// ------------------------------------------------------------------------- // // ------------------------------------------------------------------------- //

View File

@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"} modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"} modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "created_at"} modelColumns := []string{"id", "title", "created_at"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options) _, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil { if err == nil {
t.Error("Expected error when no cursor is provided") t.Error("Expected error when no cursor is provided")
} }
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title"} modelColumns := []string{"id", "title"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options) _, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil { if err == nil {
t.Error("Expected error when no sort columns are defined") t.Error("Expected error when no sort columns are defined")
} }
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"} modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "name", "email"} modelColumns := []string{"id", "name", "email"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
@@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Logf("Generated cursor filter with schema: %s", filter) t.Logf("Generated cursor filter with schema: %s", filter)
} }
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
options := common.RequestOptions{
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
CursorForward: "8975",
}
tableName := "core.account"
pkName := "rid_account"
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestGetActiveCursor(t *testing.T) { func TestGetActiveCursor(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Joined column (not supported)", name: "Joined column (isJoin=true, no error)",
field: "name", field: "name",
prefix: "user", prefix: "user",
tableName: "posts", tableName: "posts",
modelColumns: []string{"id", "title"}, modelColumns: []string{"id", "title"},
wantErr: true, wantErr: false,
// cursorCol and targetCol are empty when isJoin=true; handled by caller
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns) cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
// For join columns, cursor/target are empty and isJoin=true
if isJoin {
if cursor != "" || target != "" {
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
}
return
}
if cursor != tt.wantCursor { if cursor != tt.wantCursor {
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor) t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
} }
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
pkName := "id" pkName := "id"
modelColumns := []string{"id", "created_at"} modelColumns := []string{"id", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options) filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }

View File

@@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
} }
// Get cursor filter SQL // Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options) cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil { if err != nil {
logger.Error("Error building cursor filter: %v", err) logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err) h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)