mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-04 23:22:21 +00:00
feat(resolvemcp): add support for join-column sorting in cursor pagination
* Enhance getCursorFilter to accept join clauses for sorting * Update resolveColumn to handle joined columns * Modify tests to validate new join functionality
This commit is contained in:
71
README.md
71
README.md
@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
|
|||||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||||
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||||
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||||
|
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
|
||||||
|
|
||||||
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
|
|||||||
* [Quick Start](#quick-start)
|
* [Quick Start](#quick-start)
|
||||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||||
|
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
|
||||||
* [Architecture](#architecture)
|
* [Architecture](#architecture)
|
||||||
* [API Structure](#api-structure)
|
* [API Structure](#api-structure)
|
||||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||||
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
|
|||||||
* **🆕 Backward Compatible**: Existing code works without changes
|
* **🆕 Backward Compatible**: Existing code works without changes
|
||||||
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
||||||
|
|
||||||
|
### ResolveMCP (v3.2+)
|
||||||
|
|
||||||
|
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
|
||||||
|
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
|
||||||
|
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
|
||||||
|
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||||
|
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
|
||||||
|
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
|
||||||
|
|
||||||
### RestHeadSpec (v2.1+)
|
### RestHeadSpec (v2.1+)
|
||||||
|
|
||||||
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
||||||
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
|||||||
|
|
||||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||||
|
|
||||||
|
### ResolveMCP (MCP Server)
|
||||||
|
|
||||||
|
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := resolvemcp.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Register models — must be done BEFORE Build()
|
||||||
|
handler.RegisterModel("public", "users", &User{})
|
||||||
|
handler.RegisterModel("public", "posts", &Post{})
|
||||||
|
|
||||||
|
// Finalize: registers MCP tools and resources
|
||||||
|
handler.Build()
|
||||||
|
|
||||||
|
// Mount SSE transport on your existing router
|
||||||
|
router := mux.NewRouter()
|
||||||
|
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
|
||||||
|
|
||||||
|
// MCP clients connect to:
|
||||||
|
// SSE stream: GET http://localhost:8080/mcp/sse
|
||||||
|
// Messages: POST http://localhost:8080/mcp/message
|
||||||
|
//
|
||||||
|
// Auto-registered tools per model:
|
||||||
|
// read_public_users — filter, sort, paginate, preload
|
||||||
|
// create_public_users — insert a new record
|
||||||
|
// update_public_users — update a record by ID
|
||||||
|
// delete_public_users — delete a record by ID
|
||||||
|
```
|
||||||
|
|
||||||
|
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
### Two Complementary APIs
|
### Two Complementary APIs
|
||||||
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
|
|||||||
|
|
||||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||||
|
|
||||||
|
#### ResolveMCP - MCP Server
|
||||||
|
|
||||||
|
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
|
||||||
|
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
|
||||||
|
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
|
||||||
|
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
|
||||||
|
- Same Before/After lifecycle hooks as ResolveSpec
|
||||||
|
|
||||||
|
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
|
||||||
|
|
||||||
#### FuncSpec - Function-Based SQL API
|
#### FuncSpec - Function-Based SQL API
|
||||||
|
|
||||||
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
||||||
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
|
|
||||||
## What's New
|
## What's New
|
||||||
|
|
||||||
### v3.1 (Latest - February 2026)
|
### v3.2 (Latest - March 2026)
|
||||||
|
|
||||||
|
**ResolveMCP - Model Context Protocol Server (🆕)**:
|
||||||
|
|
||||||
|
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
|
||||||
|
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
|
||||||
|
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||||
|
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
|
||||||
|
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
|
||||||
|
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
|
||||||
|
|
||||||
|
### v3.1 (February 2026)
|
||||||
|
|
||||||
**SQLite Schema Translation (🆕)**:
|
**SQLite Schema Translation (🆕)**:
|
||||||
|
|
||||||
|
|||||||
@@ -18,11 +18,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||||
|
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
|
||||||
func getCursorFilter(
|
func getCursorFilter(
|
||||||
tableName string,
|
tableName string,
|
||||||
pkName string,
|
pkName string,
|
||||||
modelColumns []string,
|
modelColumns []string,
|
||||||
options common.RequestOptions,
|
options common.RequestOptions,
|
||||||
|
expandJoins map[string]string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
fullTableName := tableName
|
fullTableName := tableName
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
@@ -40,6 +42,7 @@ func getCursorFilter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var whereClauses []string
|
var whereClauses []string
|
||||||
|
joinSQL := ""
|
||||||
reverse := direction < 0
|
reverse := direction < 0
|
||||||
|
|
||||||
for _, s := range sortItems {
|
for _, s := range sortItems {
|
||||||
@@ -57,12 +60,27 @@ func getCursorFilter(
|
|||||||
desc = !desc
|
desc = !desc
|
||||||
}
|
}
|
||||||
|
|
||||||
cursorCol, targetCol, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isJoin {
|
||||||
|
if expandJoins != nil {
|
||||||
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
|
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
|
||||||
|
joinSQL = jSQL
|
||||||
|
cursorCol = cRef + "." + field
|
||||||
|
targetCol = prefix + "." + field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cursorCol == "" {
|
||||||
|
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
op := "<"
|
op := "<"
|
||||||
if desc {
|
if desc {
|
||||||
op = ">"
|
op = ">"
|
||||||
@@ -79,10 +97,12 @@ func getCursorFilter(
|
|||||||
query := fmt.Sprintf(`EXISTS (
|
query := fmt.Sprintf(`EXISTS (
|
||||||
SELECT 1
|
SELECT 1
|
||||||
FROM %s cursor_select
|
FROM %s cursor_select
|
||||||
|
%s
|
||||||
WHERE cursor_select.%s = %s
|
WHERE cursor_select.%s = %s
|
||||||
AND (%s)
|
AND (%s)
|
||||||
)`,
|
)`,
|
||||||
fullTableName,
|
fullTableName,
|
||||||
|
joinSQL,
|
||||||
pkName,
|
pkName,
|
||||||
cursorID,
|
cursorID,
|
||||||
orSQL,
|
orSQL,
|
||||||
@@ -101,26 +121,34 @@ func getActiveCursor(options common.RequestOptions) (id string, direction cursor
|
|||||||
return "", 0
|
return "", 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, err error) {
|
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||||
if strings.Contains(field, "->") {
|
if strings.Contains(field, "->") {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelColumns != nil {
|
if modelColumns != nil {
|
||||||
for _, col := range modelColumns {
|
for _, col := range modelColumns {
|
||||||
if strings.EqualFold(col, field) {
|
if strings.EqualFold(col, field) {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix != "" && prefix != tableName {
|
if prefix != "" && prefix != tableName {
|
||||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
return "", "", true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||||
|
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||||
|
cursorAlias = "cursor_select_" + alias
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||||
|
return joinSQL, cursorAlias
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildCursorPriorityChain(clauses []string) string {
|
func buildCursorPriorityChain(clauses []string) string {
|
||||||
|
|||||||
@@ -191,7 +191,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
|
|||||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
}
|
}
|
||||||
|
|
||||||
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options)
|
// expandJoins is empty for resolvemcp — no custom SQL join support yet
|
||||||
|
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// toolName builds the MCP tool name for a given operation and model.
|
// toolName builds the MCP tool name for a given operation and model.
|
||||||
@@ -22,54 +24,268 @@ func toolName(operation, schema, entity string) string {
|
|||||||
|
|
||||||
// registerModelTools registers the four CRUD tools and resource for a model.
|
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||||
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||||
registerReadTool(h, schema, entity)
|
info := buildModelInfo(schema, entity, model)
|
||||||
registerCreateTool(h, schema, entity)
|
registerReadTool(h, schema, entity, info)
|
||||||
registerUpdateTool(h, schema, entity)
|
registerCreateTool(h, schema, entity, info)
|
||||||
registerDeleteTool(h, schema, entity)
|
registerUpdateTool(h, schema, entity, info)
|
||||||
registerModelResource(h, schema, entity)
|
registerDeleteTool(h, schema, entity, info)
|
||||||
|
registerModelResource(h, schema, entity, info)
|
||||||
|
|
||||||
logger.Info("[resolvemcp] Registered MCP tools for %s.%s", schema, entity)
|
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Model introspection
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
|
||||||
|
type modelInfo struct {
|
||||||
|
fullName string // e.g. "public.users"
|
||||||
|
pkName string // e.g. "id"
|
||||||
|
columns []columnInfo
|
||||||
|
relationNames []string
|
||||||
|
schemaDoc string // formatted multi-line schema listing
|
||||||
|
}
|
||||||
|
|
||||||
|
type columnInfo struct {
|
||||||
|
jsonName string
|
||||||
|
sqlName string
|
||||||
|
goType string
|
||||||
|
sqlType string
|
||||||
|
isPrimary bool
|
||||||
|
isUnique bool
|
||||||
|
isFK bool
|
||||||
|
nullable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
|
||||||
|
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||||
|
info := modelInfo{
|
||||||
|
fullName: buildModelName(schema, entity),
|
||||||
|
pkName: reflection.GetPrimaryKeyName(model),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap to base struct type
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
|
||||||
|
|
||||||
|
for _, d := range details {
|
||||||
|
// Derive the JSON name from the struct field
|
||||||
|
jsonName := fieldJSONName(modelType, d.Name)
|
||||||
|
if jsonName == "" || jsonName == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (slice or user-defined struct that isn't time.Time).
|
||||||
|
fieldType, found := modelType.FieldByName(d.Name)
|
||||||
|
if found {
|
||||||
|
ft := fieldType.Type
|
||||||
|
if ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||||
|
if ft.Kind() == reflect.Slice || isUserStruct {
|
||||||
|
info.relationNames = append(info.relationNames, jsonName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlName := d.SQLName
|
||||||
|
if sqlName == "" {
|
||||||
|
sqlName = jsonName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive Go type name, unwrapping pointer if needed.
|
||||||
|
goType := d.DataType
|
||||||
|
if goType == "" && found {
|
||||||
|
ft := fieldType.Type
|
||||||
|
for ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
goType = ft.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrimary: use both the GORM-tag detection and a name comparison against
|
||||||
|
// the known primary key (handles camelCase "primaryKey" tags correctly).
|
||||||
|
isPrimary := d.SQLKey == "primary_key" ||
|
||||||
|
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
|
||||||
|
|
||||||
|
ci := columnInfo{
|
||||||
|
jsonName: jsonName,
|
||||||
|
sqlName: sqlName,
|
||||||
|
goType: goType,
|
||||||
|
sqlType: d.SQLDataType,
|
||||||
|
isPrimary: isPrimary,
|
||||||
|
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
|
||||||
|
isFK: d.SQLKey == "foreign_key",
|
||||||
|
nullable: d.Nullable,
|
||||||
|
}
|
||||||
|
info.columns = append(info.columns, ci)
|
||||||
|
}
|
||||||
|
|
||||||
|
info.schemaDoc = buildSchemaDoc(info)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
|
||||||
|
func fieldJSONName(modelType reflect.Type, fieldName string) string {
|
||||||
|
field, ok := modelType.FieldByName(fieldName)
|
||||||
|
if !ok {
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
tag := field.Tag.Get("json")
|
||||||
|
if tag == "" {
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(tag, ",", 2)
|
||||||
|
if parts[0] == "" {
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
|
||||||
|
func buildSchemaDoc(info modelInfo) string {
|
||||||
|
if len(info.columns) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString("Columns:\n")
|
||||||
|
for _, c := range info.columns {
|
||||||
|
line := fmt.Sprintf(" • %s", c.jsonName)
|
||||||
|
|
||||||
|
typeDesc := c.goType
|
||||||
|
if c.sqlType != "" {
|
||||||
|
typeDesc = c.sqlType
|
||||||
|
}
|
||||||
|
if typeDesc != "" {
|
||||||
|
line += fmt.Sprintf(" (%s)", typeDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags []string
|
||||||
|
if c.isPrimary {
|
||||||
|
flags = append(flags, "primary key")
|
||||||
|
}
|
||||||
|
if c.isUnique {
|
||||||
|
flags = append(flags, "unique")
|
||||||
|
}
|
||||||
|
if c.isFK {
|
||||||
|
flags = append(flags, "foreign key")
|
||||||
|
}
|
||||||
|
if !c.nullable && !c.isPrimary {
|
||||||
|
flags = append(flags, "not null")
|
||||||
|
} else if c.nullable {
|
||||||
|
flags = append(flags, "nullable")
|
||||||
|
}
|
||||||
|
if len(flags) > 0 {
|
||||||
|
line += " — " + strings.Join(flags, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(line + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(info.relationNames) > 0 {
|
||||||
|
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
|
||||||
|
func columnNameList(cols []columnInfo) string {
|
||||||
|
names := make([]string, len(cols))
|
||||||
|
for i, c := range cols {
|
||||||
|
names[i] = c.jsonName
|
||||||
|
}
|
||||||
|
return strings.Join(names, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// writableColumnNames returns JSON names for all non-primary-key columns.
|
||||||
|
func writableColumnNames(cols []columnInfo) []string {
|
||||||
|
var names []string
|
||||||
|
for _, c := range cols {
|
||||||
|
if !c.isPrimary {
|
||||||
|
names = append(names, c.jsonName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Read tool
|
// Read tool
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
func registerReadTool(h *Handler, schema, entity string) {
|
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
name := toolName("read", schema, entity)
|
name := toolName("read", schema, entity)
|
||||||
description := fmt.Sprintf("Read records from %s", buildModelName(schema, entity))
|
|
||||||
|
var descParts []string
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
|
||||||
|
}
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
descParts = append(descParts, info.schemaDoc)
|
||||||
|
}
|
||||||
|
descParts = append(descParts,
|
||||||
|
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
|
||||||
|
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
|
||||||
|
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
|
||||||
|
)
|
||||||
|
if len(info.relationNames) > 0 {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
description := strings.Join(descParts, "\n\n")
|
||||||
|
|
||||||
|
filterDesc := fmt.Sprintf(`Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`)
|
||||||
|
if len(info.columns) > 0 {
|
||||||
|
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
|
||||||
|
if len(info.columns) > 0 {
|
||||||
|
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||||
|
}
|
||||||
|
|
||||||
tool := mcp.NewTool(name,
|
tool := mcp.NewTool(name,
|
||||||
mcp.WithDescription(description),
|
mcp.WithDescription(description),
|
||||||
mcp.WithString("id",
|
mcp.WithString("id",
|
||||||
mcp.Description("Primary key of a single record to fetch (optional)"),
|
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
|
||||||
),
|
),
|
||||||
mcp.WithNumber("limit",
|
mcp.WithNumber("limit",
|
||||||
mcp.Description("Maximum number of records to return"),
|
mcp.Description("Maximum number of records to return per page. Recommended: 10–100."),
|
||||||
),
|
),
|
||||||
mcp.WithNumber("offset",
|
mcp.WithNumber("offset",
|
||||||
mcp.Description("Number of records to skip"),
|
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
|
||||||
),
|
),
|
||||||
mcp.WithString("cursor_forward",
|
mcp.WithString("cursor_forward",
|
||||||
mcp.Description("Cursor value for the next page (primary key of last record on current page)"),
|
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||||
),
|
),
|
||||||
mcp.WithString("cursor_backward",
|
mcp.WithString("cursor_backward",
|
||||||
mcp.Description("Cursor value for the previous page"),
|
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||||
),
|
),
|
||||||
mcp.WithArray("columns",
|
mcp.WithArray("columns",
|
||||||
mcp.Description("List of column names to include in the result"),
|
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
|
||||||
),
|
),
|
||||||
mcp.WithArray("omit_columns",
|
mcp.WithArray("omit_columns",
|
||||||
mcp.Description("List of column names to exclude from the result"),
|
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
|
||||||
),
|
),
|
||||||
mcp.WithArray("filters",
|
mcp.WithArray("filters",
|
||||||
mcp.Description(`Array of filter objects. Each object: {"column":"name","operator":"=","value":"val","logic_operator":"AND|OR"}. Operators: =, !=, >, <, >=, <=, like, ilike, in, is_null, is_not_null`),
|
mcp.Description(filterDesc),
|
||||||
),
|
),
|
||||||
mcp.WithArray("sort",
|
mcp.WithArray("sort",
|
||||||
mcp.Description(`Array of sort objects. Each object: {"column":"name","direction":"asc|desc"}`),
|
mcp.Description(sortDesc),
|
||||||
),
|
),
|
||||||
mcp.WithArray("preloads",
|
mcp.WithArray("preloads",
|
||||||
mcp.Description(`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1"]}`),
|
mcp.Description(buildPreloadDesc(info)),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -91,18 +307,52 @@ func registerReadTool(h *Handler, schema, entity string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildPreloadDesc(info modelInfo) string {
|
||||||
|
if len(info.relationNames) == 0 {
|
||||||
|
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
|
||||||
|
strings.Join(info.relationNames, ", "),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Create tool
|
// Create tool
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
func registerCreateTool(h *Handler, schema, entity string) {
|
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
name := toolName("create", schema, entity)
|
name := toolName("create", schema, entity)
|
||||||
description := fmt.Sprintf("Create one or more records in %s", buildModelName(schema, entity))
|
|
||||||
|
writable := writableColumnNames(info.columns)
|
||||||
|
|
||||||
|
var descParts []string
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
|
||||||
|
if len(writable) > 0 {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
|
||||||
|
}
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
|
||||||
|
}
|
||||||
|
descParts = append(descParts,
|
||||||
|
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
|
||||||
|
)
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
descParts = append(descParts, info.schemaDoc)
|
||||||
|
}
|
||||||
|
|
||||||
|
description := strings.Join(descParts, "\n\n")
|
||||||
|
|
||||||
|
dataDesc := "Record fields to create."
|
||||||
|
if len(writable) > 0 {
|
||||||
|
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
|
||||||
|
}
|
||||||
|
dataDesc += " Pass a single object or an array of objects."
|
||||||
|
|
||||||
tool := mcp.NewTool(name,
|
tool := mcp.NewTool(name,
|
||||||
mcp.WithDescription(description),
|
mcp.WithDescription(description),
|
||||||
mcp.WithObject("data",
|
mcp.WithObject("data",
|
||||||
mcp.Description("Record fields to create (single object), or pass an array as the 'items' key"),
|
mcp.Description(dataDesc),
|
||||||
mcp.Required(),
|
mcp.Required(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -130,17 +380,42 @@ func registerCreateTool(h *Handler, schema, entity string) {
|
|||||||
// Update tool
|
// Update tool
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
func registerUpdateTool(h *Handler, schema, entity string) {
|
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
name := toolName("update", schema, entity)
|
name := toolName("update", schema, entity)
|
||||||
description := fmt.Sprintf("Update an existing record in %s", buildModelName(schema, entity))
|
|
||||||
|
writable := writableColumnNames(info.columns)
|
||||||
|
|
||||||
|
var descParts []string
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
|
||||||
|
}
|
||||||
|
if len(writable) > 0 {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
|
||||||
|
}
|
||||||
|
descParts = append(descParts,
|
||||||
|
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
|
||||||
|
)
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
descParts = append(descParts, info.schemaDoc)
|
||||||
|
}
|
||||||
|
|
||||||
|
description := strings.Join(descParts, "\n\n")
|
||||||
|
|
||||||
|
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
|
||||||
|
|
||||||
|
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
|
||||||
|
if len(writable) > 0 {
|
||||||
|
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
tool := mcp.NewTool(name,
|
tool := mcp.NewTool(name,
|
||||||
mcp.WithDescription(description),
|
mcp.WithDescription(description),
|
||||||
mcp.WithString("id",
|
mcp.WithString("id",
|
||||||
mcp.Description("Primary key of the record to update"),
|
mcp.Description(idDesc),
|
||||||
),
|
),
|
||||||
mcp.WithObject("data",
|
mcp.WithObject("data",
|
||||||
mcp.Description("Fields to update (non-null fields will be merged into the existing record)"),
|
mcp.Description(dataDesc),
|
||||||
mcp.Required(),
|
mcp.Required(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -174,14 +449,23 @@ func registerUpdateTool(h *Handler, schema, entity string) {
|
|||||||
// Delete tool
|
// Delete tool
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
func registerDeleteTool(h *Handler, schema, entity string) {
|
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
name := toolName("delete", schema, entity)
|
name := toolName("delete", schema, entity)
|
||||||
description := fmt.Sprintf("Delete a record from %s by primary key", buildModelName(schema, entity))
|
|
||||||
|
descParts := []string{
|
||||||
|
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
|
||||||
|
}
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
|
||||||
|
}
|
||||||
|
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
|
||||||
|
|
||||||
|
description := strings.Join(descParts, " ")
|
||||||
|
|
||||||
tool := mcp.NewTool(name,
|
tool := mcp.NewTool(name,
|
||||||
mcp.WithDescription(description),
|
mcp.WithDescription(description),
|
||||||
mcp.WithString("id",
|
mcp.WithString("id",
|
||||||
mcp.Description("Primary key of the record to delete"),
|
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
|
||||||
mcp.Required(),
|
mcp.Required(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -206,17 +490,23 @@ func registerDeleteTool(h *Handler, schema, entity string) {
|
|||||||
// Resource registration
|
// Resource registration
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
func registerModelResource(h *Handler, schema, entity string) {
|
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
|
||||||
resourceURI := buildModelName(schema, entity)
|
resourceURI := info.fullName
|
||||||
displayName := entity
|
|
||||||
if schema != "" {
|
var resourceDesc strings.Builder
|
||||||
displayName = schema + "." + entity
|
resourceDesc.WriteString(fmt.Sprintf("Database table: %s", info.fullName))
|
||||||
|
if info.pkName != "" {
|
||||||
|
resourceDesc.WriteString(fmt.Sprintf(" (primary key: %s)", info.pkName))
|
||||||
|
}
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
resourceDesc.WriteString("\n\n")
|
||||||
|
resourceDesc.WriteString(info.schemaDoc)
|
||||||
}
|
}
|
||||||
|
|
||||||
resource := mcp.NewResource(
|
resource := mcp.NewResource(
|
||||||
resourceURI,
|
resourceURI,
|
||||||
displayName,
|
entity,
|
||||||
mcp.WithResourceDescription(fmt.Sprintf("Database table: %s", displayName)),
|
mcp.WithResourceDescription(resourceDesc.String()),
|
||||||
mcp.WithMIMEType("application/json"),
|
mcp.WithMIMEType("application/json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -256,7 +546,6 @@ func registerModelResource(h *Handler, schema, entity string) {
|
|||||||
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||||
options := common.RequestOptions{}
|
options := common.RequestOptions{}
|
||||||
|
|
||||||
// limit
|
|
||||||
if v, ok := args["limit"]; ok {
|
if v, ok := args["limit"]; ok {
|
||||||
switch n := v.(type) {
|
switch n := v.(type) {
|
||||||
case float64:
|
case float64:
|
||||||
@@ -267,7 +556,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// offset
|
|
||||||
if v, ok := args["offset"]; ok {
|
if v, ok := args["offset"]; ok {
|
||||||
switch n := v.(type) {
|
switch n := v.(type) {
|
||||||
case float64:
|
case float64:
|
||||||
@@ -278,7 +566,6 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// cursor_forward / cursor_backward
|
|
||||||
if v, ok := args["cursor_forward"].(string); ok {
|
if v, ok := args["cursor_forward"].(string); ok {
|
||||||
options.CursorForward = v
|
options.CursorForward = v
|
||||||
}
|
}
|
||||||
@@ -286,19 +573,10 @@ func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
|||||||
options.CursorBackward = v
|
options.CursorBackward = v
|
||||||
}
|
}
|
||||||
|
|
||||||
// columns
|
|
||||||
options.Columns = parseStringArray(args["columns"])
|
options.Columns = parseStringArray(args["columns"])
|
||||||
|
|
||||||
// omit_columns
|
|
||||||
options.OmitColumns = parseStringArray(args["omit_columns"])
|
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||||
|
|
||||||
// filters — marshal each item and unmarshal into FilterOption
|
|
||||||
options.Filters = parseFilters(args["filters"])
|
options.Filters = parseFilters(args["filters"])
|
||||||
|
|
||||||
// sort
|
|
||||||
options.Sort = parseSortOptions(args["sort"])
|
options.Sort = parseSortOptions(args["sort"])
|
||||||
|
|
||||||
// preloads
|
|
||||||
options.Preload = parsePreloadOptions(args["preloads"])
|
options.Preload = parsePreloadOptions(args["preloads"])
|
||||||
|
|
||||||
return options
|
return options
|
||||||
@@ -342,7 +620,6 @@ func parseFilters(raw interface{}) []common.FilterOption {
|
|||||||
if f.Column == "" || f.Operator == "" {
|
if f.Column == "" || f.Operator == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Normalise logic operator
|
|
||||||
if strings.EqualFold(f.LogicOperator, "or") {
|
if strings.EqualFold(f.LogicOperator, "or") {
|
||||||
f.LogicOperator = "OR"
|
f.LogicOperator = "OR"
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ const (
|
|||||||
// - pkName: primary key column (e.g. "id")
|
// - pkName: primary key column (e.g. "id")
|
||||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||||
// - options: the request options containing sort and cursor information
|
// - options: the request options containing sort and cursor information
|
||||||
|
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
|
||||||
//
|
//
|
||||||
// Returns SQL snippet to embed in WHERE clause.
|
// Returns SQL snippet to embed in WHERE clause.
|
||||||
func GetCursorFilter(
|
func GetCursorFilter(
|
||||||
@@ -31,6 +32,7 @@ func GetCursorFilter(
|
|||||||
pkName string,
|
pkName string,
|
||||||
modelColumns []string,
|
modelColumns []string,
|
||||||
options common.RequestOptions,
|
options common.RequestOptions,
|
||||||
|
expandJoins map[string]string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
// Separate schema prefix from bare table name
|
// Separate schema prefix from bare table name
|
||||||
fullTableName := tableName
|
fullTableName := tableName
|
||||||
@@ -58,6 +60,7 @@ func GetCursorFilter(
|
|||||||
// 3. Prepare
|
// 3. Prepare
|
||||||
// --------------------------------------------------------------------- //
|
// --------------------------------------------------------------------- //
|
||||||
var whereClauses []string
|
var whereClauses []string
|
||||||
|
joinSQL := ""
|
||||||
reverse := direction < 0
|
reverse := direction < 0
|
||||||
|
|
||||||
// --------------------------------------------------------------------- //
|
// --------------------------------------------------------------------- //
|
||||||
@@ -69,7 +72,7 @@ func GetCursorFilter(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse: "created_at", "user.name", etc.
|
// Parse: "created_at", "user.name", "fn.sortorder", etc.
|
||||||
parts := strings.Split(col, ".")
|
parts := strings.Split(col, ".")
|
||||||
field := strings.TrimSpace(parts[len(parts)-1])
|
field := strings.TrimSpace(parts[len(parts)-1])
|
||||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||||
@@ -82,7 +85,7 @@ func GetCursorFilter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve column
|
// Resolve column
|
||||||
cursorCol, targetCol, err := resolveColumn(
|
cursorCol, targetCol, isJoin, err := resolveColumn(
|
||||||
field, prefix, tableName, modelColumns,
|
field, prefix, tableName, modelColumns,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -90,6 +93,22 @@ func GetCursorFilter(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle joins
|
||||||
|
if isJoin {
|
||||||
|
if expandJoins != nil {
|
||||||
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
|
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||||
|
joinSQL = jSQL
|
||||||
|
cursorCol = cRef + "." + field
|
||||||
|
targetCol = prefix + "." + field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cursorCol == "" {
|
||||||
|
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build inequality
|
// Build inequality
|
||||||
op := "<"
|
op := "<"
|
||||||
if desc {
|
if desc {
|
||||||
@@ -113,10 +132,12 @@ func GetCursorFilter(
|
|||||||
query := fmt.Sprintf(`EXISTS (
|
query := fmt.Sprintf(`EXISTS (
|
||||||
SELECT 1
|
SELECT 1
|
||||||
FROM %s cursor_select
|
FROM %s cursor_select
|
||||||
|
%s
|
||||||
WHERE cursor_select.%s = %s
|
WHERE cursor_select.%s = %s
|
||||||
AND (%s)
|
AND (%s)
|
||||||
)`,
|
)`,
|
||||||
fullTableName,
|
fullTableName,
|
||||||
|
joinSQL,
|
||||||
pkName,
|
pkName,
|
||||||
cursorID,
|
cursorID,
|
||||||
orSQL,
|
orSQL,
|
||||||
@@ -137,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
|
|||||||
return "", 0
|
return "", 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper: resolve column (main table only for now)
|
// Helper: resolve column (main table or join)
|
||||||
func resolveColumn(
|
func resolveColumn(
|
||||||
field, prefix, tableName string,
|
field, prefix, tableName string,
|
||||||
modelColumns []string,
|
modelColumns []string,
|
||||||
) (cursorCol, targetCol string, err error) {
|
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||||
|
|
||||||
// JSON field
|
// JSON field
|
||||||
if strings.Contains(field, "->") {
|
if strings.Contains(field, "->") {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Main table column
|
// Main table column
|
||||||
if modelColumns != nil {
|
if modelColumns != nil {
|
||||||
for _, col := range modelColumns {
|
for _, col := range modelColumns {
|
||||||
if strings.EqualFold(col, field) {
|
if strings.EqualFold(col, field) {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No validation → allow all main-table fields
|
// No validation → allow all main-table fields
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Joined column (not supported in resolvespec yet)
|
// Joined column
|
||||||
if prefix != "" && prefix != tableName {
|
if prefix != "" && prefix != tableName {
|
||||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
return "", "", true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: rewrite JOIN clause for cursor subquery
|
||||||
|
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||||
|
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||||
|
cursorAlias = "cursor_select_" + alias
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||||
|
return joinSQL, cursorAlias
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------------- //
|
// ------------------------------------------------------------------------- //
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "created_at"}
|
modelColumns := []string{"id", "title", "created_at"}
|
||||||
|
|
||||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when no cursor is provided")
|
t.Error("Expected error when no cursor is provided")
|
||||||
}
|
}
|
||||||
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title"}
|
modelColumns := []string{"id", "title"}
|
||||||
|
|
||||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when no sort columns are defined")
|
t.Error("Expected error when no sort columns are defined")
|
||||||
}
|
}
|
||||||
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -170,7 +170,7 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "name", "email"}
|
modelColumns := []string{"id", "name", "email"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -183,6 +183,37 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
|||||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||||
|
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||||
|
|
||||||
|
options := common.RequestOptions{
|
||||||
|
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
|
||||||
|
CursorForward: "8975",
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := "core.account"
|
||||||
|
pkName := "rid_account"
|
||||||
|
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||||
|
expandJoins := map[string]string{"fn": lateralJoin}
|
||||||
|
|
||||||
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||||
|
|
||||||
|
if !strings.Contains(filter, "cursor_select_fn") {
|
||||||
|
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||||
|
}
|
||||||
|
if !strings.Contains(filter, "sortorder") {
|
||||||
|
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||||
|
}
|
||||||
|
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||||
|
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetActiveCursor(t *testing.T) {
|
func TestGetActiveCursor(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Joined column (not supported)",
|
name: "Joined column (isJoin=true, no error)",
|
||||||
field: "name",
|
field: "name",
|
||||||
prefix: "user",
|
prefix: "user",
|
||||||
tableName: "posts",
|
tableName: "posts",
|
||||||
modelColumns: []string{"id", "title"},
|
modelColumns: []string{"id", "title"},
|
||||||
wantErr: true,
|
wantErr: false,
|
||||||
|
// cursorCol and targetCol are empty when isJoin=true; handled by caller
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
|
|||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For join columns, cursor/target are empty and isJoin=true
|
||||||
|
if isJoin {
|
||||||
|
if cursor != "" || target != "" {
|
||||||
|
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if cursor != tt.wantCursor {
|
if cursor != tt.wantCursor {
|
||||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||||
}
|
}
|
||||||
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "created_at"}
|
modelColumns := []string{"id", "created_at"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -334,8 +334,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get cursor filter SQL
|
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
|
||||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error building cursor filter: %v", err)
|
logger.Error("Error building cursor filter: %v", err)
|
||||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user