feat(resolvemcp): add support for join-column sorting in cursor pagination

* Enhance getCursorFilter to accept join clauses for sorting
* Update resolveColumn to handle joined columns
* Modify tests to validate new join functionality
This commit is contained in:
Hein
2026-03-27 13:10:42 +02:00
parent 835bbb0727
commit 7f6410f665
7 changed files with 524 additions and 79 deletions

View File

@@ -18,11 +18,13 @@ const (
)
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
func getCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
fullTableName := tableName
if strings.Contains(tableName, ".") {
@@ -40,6 +42,7 @@ func getCursorFilter(
}
var whereClauses []string
joinSQL := ""
reverse := direction < 0
for _, s := range sortItems {
@@ -57,12 +60,27 @@ func getCursorFilter(
desc = !desc
}
cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
op := "<"
if desc {
op = ">"
@@ -79,10 +97,12 @@ func getCursorFilter(
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
@@ -101,26 +121,34 @@ func getActiveCursor(options common.RequestOptions) (id string, direction cursor
return "", 0
}
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) {
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
return "", "", true, nil
}
return "", "", fmt.Errorf("invalid column: %s", field)
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
func buildCursorPriorityChain(clauses []string) string {

View File

@@ -191,7 +191,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options)
// expandJoins is empty for resolvemcp — no custom SQL join support yet
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err)
}

View File

@@ -4,12 +4,14 @@ import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// toolName builds the MCP tool name for a given operation and model.
@@ -22,54 +24,268 @@ func toolName(operation, schema, entity string) string {
// registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
registerReadTool(h, schema, entity)
registerCreateTool(h, schema, entity)
registerUpdateTool(h, schema, entity)
registerDeleteTool(h, schema, entity)
registerModelResource(h, schema, entity)
info := buildModelInfo(schema, entity, model)
registerReadTool(h, schema, entity, info)
registerCreateTool(h, schema, entity, info)
registerUpdateTool(h, schema, entity, info)
registerDeleteTool(h, schema, entity, info)
registerModelResource(h, schema, entity, info)
logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity)
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
}
// --------------------------------------------------------------------------
// Model introspection
// --------------------------------------------------------------------------
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
type modelInfo struct {
fullName string // e.g. "public.users"
pkName string // e.g. "id"
columns []columnInfo
relationNames []string
schemaDoc string // formatted multi-line schema listing
}
type columnInfo struct {
jsonName string
sqlName string
goType string
sqlType string
isPrimary bool
isUnique bool
isFK bool
nullable bool
}
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
info := modelInfo{
fullName: buildModelName(schema, entity),
pkName: reflection.GetPrimaryKeyName(model),
}
// Unwrap to base struct type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return info
}
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
for _, d := range details {
// Derive the JSON name from the struct field
jsonName := fieldJSONName(modelType, d.Name)
if jsonName == "" || jsonName == "-" {
continue
}
// Skip relation fields (slice or user-defined struct that isn't time.Time).
fieldType, found := modelType.FieldByName(d.Name)
if found {
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
if ft.Kind() == reflect.Slice || isUserStruct {
info.relationNames = append(info.relationNames, jsonName)
continue
}
}
sqlName := d.SQLName
if sqlName == "" {
sqlName = jsonName
}
// Derive Go type name, unwrapping pointer if needed.
goType := d.DataType
if goType == "" && found {
ft := fieldType.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
goType = ft.Name()
}
// isPrimary: use both the GORM-tag detection and a name comparison against
// the known primary key (handles camelCase "primaryKey" tags correctly).
isPrimary := d.SQLKey == "primary_key" ||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
ci := columnInfo{
jsonName: jsonName,
sqlName: sqlName,
goType: goType,
sqlType: d.SQLDataType,
isPrimary: isPrimary,
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
isFK: d.SQLKey == "foreign_key",
nullable: d.Nullable,
}
info.columns = append(info.columns, ci)
}
info.schemaDoc = buildSchemaDoc(info)
return info
}
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
func fieldJSONName(modelType reflect.Type, fieldName string) string {
field, ok := modelType.FieldByName(fieldName)
if !ok {
return fieldName
}
tag := field.Tag.Get("json")
if tag == "" {
return fieldName
}
parts := strings.SplitN(tag, ",", 2)
if parts[0] == "" {
return fieldName
}
return parts[0]
}
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
func buildSchemaDoc(info modelInfo) string {
if len(info.columns) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("Columns:\n")
for _, c := range info.columns {
line := fmt.Sprintf(" • %s", c.jsonName)
typeDesc := c.goType
if c.sqlType != "" {
typeDesc = c.sqlType
}
if typeDesc != "" {
line += fmt.Sprintf(" (%s)", typeDesc)
}
var flags []string
if c.isPrimary {
flags = append(flags, "primary key")
}
if c.isUnique {
flags = append(flags, "unique")
}
if c.isFK {
flags = append(flags, "foreign key")
}
if !c.nullable && !c.isPrimary {
flags = append(flags, "not null")
} else if c.nullable {
flags = append(flags, "nullable")
}
if len(flags) > 0 {
line += " — " + strings.Join(flags, ", ")
}
sb.WriteString(line + "\n")
}
if len(info.relationNames) > 0 {
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
}
return sb.String()
}
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
func columnNameList(cols []columnInfo) string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.jsonName
}
return strings.Join(names, ", ")
}
// writableColumnNames returns JSON names for all non-primary-key columns.
func writableColumnNames(cols []columnInfo) []string {
var names []string
for _, c := range cols {
if !c.isPrimary {
names = append(names, c.jsonName)
}
}
return names
}
// --------------------------------------------------------------------------
// Read tool
// --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string) {
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("read", schema, entity)
description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity))
var descParts []string
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
}
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
descParts = append(descParts,
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
)
if len(info.relationNames) > 0 {
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
}
description := strings.Join(descParts, "\n\n")
filterDesc := fmt.Sprintf(`Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`)
if len(info.columns) > 0 {
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
if len(info.columns) > 0 {
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description("Primary key of a single record to fetch (optional)"),
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
),
mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return"),
mcp.Description("Maximum number of records to return per page. Recommended: 10100."),
),
mcp.WithNumber("offset",
mcp.Description("Number of records to skip"),
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
),
mcp.WithString("cursor_forward",
mcp.Description("Cursor value for the next page (primary key of last record on current page)"),
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithString("cursor_backward",
mcp.Description("Cursor value for the previous page"),
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithArray("columns",
mcp.Description("List of column names to include in the result"),
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("omit_columns",
mcp.Description("List of column names to exclude from the result"),
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("filters",
mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`),
mcp.Description(filterDesc),
),
mcp.WithArray("sort",
mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`),
mcp.Description(sortDesc),
),
mcp.WithArray("preloads",
mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`),
mcp.Description(buildPreloadDesc(info)),
),
)
@@ -91,18 +307,52 @@ func registerReadTool(h *Handler, schema, entity string) {
})
}
func buildPreloadDesc(info modelInfo) string {
if len(info.relationNames) == 0 {
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
}
return fmt.Sprintf(
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
strings.Join(info.relationNames, ", "),
)
}
// --------------------------------------------------------------------------
// Create tool
// --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string) {
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("create", schema, entity)
description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity))
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
}
descParts = append(descParts,
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
dataDesc := "Record fields to create."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
}
dataDesc += " Pass a single object or an array of objects."
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithObject("data",
mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"),
mcp.Description(dataDesc),
mcp.Required(),
),
)
@@ -130,17 +380,42 @@ func registerCreateTool(h *Handler, schema, entity string) {
// Update tool
// --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string) {
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("update", schema, entity)
description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity))
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
}
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
}
descParts = append(descParts,
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description("Primary key of the record to update"),
mcp.Description(idDesc),
),
mcp.WithObject("data",
mcp.Description("Fields to update (non-null fields will be merged into the existing record)"),
mcp.Description(dataDesc),
mcp.Required(),
),
)
@@ -174,14 +449,23 @@ func registerUpdateTool(h *Handler, schema, entity string) {
// Delete tool
// --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string) {
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("delete", schema, entity)
description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity))
descParts := []string{
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
}
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
description := strings.Join(descParts, " ")
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description("Primary key of the record to delete"),
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
mcp.Required(),
),
)
@@ -206,17 +490,23 @@ func registerDeleteTool(h *Handler, schema, entity string) {
// Resource registration
// --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string) {
resourceURI := buildModelName(schema, entity)
displayName := entity
if schema != "" {
displayName = schema + "." + entity
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
resourceURI := info.fullName
var resourceDesc strings.Builder
resourceDesc.WriteString(fmt.Sprintf("Database table: %s", info.fullName))
if info.pkName != "" {
resourceDesc.WriteString(fmt.Sprintf(" (primary key: %s)", info.pkName))
}
if info.schemaDoc != "" {
resourceDesc.WriteString("\n\n")
resourceDesc.WriteString(info.schemaDoc)
}
resource := mcp.NewResource(
resourceURI,
displayName,
mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)),
entity,
mcp.WithResourceDescription(resourceDesc.String()),
mcp.WithMIMEType("application/json"),
)
@@ -256,7 +546,6 @@ func registerModelResource(h *Handler, schema, entity string) {
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{}
// limit
if v, ok := args["limit"]; ok {
switch n := v.(type) {
case float64:
@@ -267,7 +556,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
}
}
// offset
if v, ok := args["offset"]; ok {
switch n := v.(type) {
case float64:
@@ -278,7 +566,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
}
}
// cursor_forward / cursor_backward
if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v
}
@@ -286,19 +573,10 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options.CursorBackward = v
}
// columns
options.Columns = parseStringArray(args["columns"])
// omit_columns
options.OmitColumns = parseStringArray(args["omit_columns"])
// filters — marshal each item and unmarshal into FilterOption
options.Filters = parseFilters(args["filters"])
// sort
options.Sort = parseSortOptions(args["sort"])
// preloads
options.Preload = parsePreloadOptions(args["preloads"])
return options
@@ -342,7 +620,6 @@ func parseFilters(raw interface{}) []common.FilterOption {
if f.Column == "" || f.Operator == "" {
continue
}
// Normalise logic operator
if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR"
} else {