mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-01 01:34:25 +00:00
Cursor pagnation for resolvespec
This commit is contained in:
179
pkg/resolvespec/cursor.go
Normal file
179
pkg/resolvespec/cursor.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// CursorDirection defines pagination direction
|
||||
type CursorDirection int
|
||||
|
||||
const (
|
||||
CursorForward CursorDirection = 1
|
||||
CursorBackward CursorDirection = -1
|
||||
)
|
||||
|
||||
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
|
||||
// It uses the current request's sort and cursor values.
|
||||
//
|
||||
// Parameters:
|
||||
// - tableName: name of the main table (e.g. "posts")
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - options: the request options containing sort and cursor information
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func GetCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
) (string, error) {
|
||||
// Remove schema prefix if present
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 1. Determine active cursor
|
||||
// --------------------------------------------------------------------- //
|
||||
cursorID, direction := getActiveCursor(options)
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 2. Extract sort columns
|
||||
// --------------------------------------------------------------------- //
|
||||
sortItems := options.Sort
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "created_at", "user.name", etc.
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
// Direction from struct
|
||||
desc := strings.EqualFold(s.Direction, "desc")
|
||||
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, err := resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 5. Build priority OR-AND chain
|
||||
// --------------------------------------------------------------------- //
|
||||
orSQL := buildPriorityChain(whereClauses)
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 6. Final EXISTS subquery
|
||||
// --------------------------------------------------------------------- //
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: get active cursor (forward or backward)
|
||||
func getActiveCursor(options common.RequestOptions) (id string, direction CursorDirection) {
|
||||
if options.CursorForward != "" {
|
||||
return options.CursorForward, CursorForward
|
||||
}
|
||||
if options.CursorBackward != "" {
|
||||
return options.CursorBackward, CursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: resolve column (main table only for now)
|
||||
func resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
}
|
||||
|
||||
// Joined column (not supported in resolvespec yet)
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: build OR-AND priority chain
|
||||
func buildPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
378
pkg/resolvespec/cursor_test.go
Normal file
378
pkg/resolvespec/cursor_test.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
CursorForward: "123",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
if filter == "" {
|
||||
t.Fatal("Expected non-empty cursor filter")
|
||||
}
|
||||
|
||||
// Verify filter contains EXISTS subquery
|
||||
if !strings.Contains(filter, "EXISTS") {
|
||||
t.Errorf("Filter should contain EXISTS subquery, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter references the cursor ID
|
||||
if !strings.Contains(filter, "123") {
|
||||
t.Errorf("Filter should reference cursor ID 123, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter contains the table name
|
||||
if !strings.Contains(filter, tableName) {
|
||||
t.Errorf("Filter should reference table name %s, got: %s", tableName, filter)
|
||||
}
|
||||
|
||||
// Verify filter contains primary key
|
||||
if !strings.Contains(filter, pkName) {
|
||||
t.Errorf("Filter should reference primary key %s, got: %s", pkName, filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
CursorBackward: "456",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
if filter == "" {
|
||||
t.Fatal("Expected non-empty cursor filter")
|
||||
}
|
||||
|
||||
// Verify filter contains cursor ID
|
||||
if !strings.Contains(filter, "456") {
|
||||
t.Errorf("Filter should reference cursor ID 456, got: %s", filter)
|
||||
}
|
||||
|
||||
// For backward cursor, sort direction should be reversed
|
||||
// This is handled internally by the GetCursorFilter function
|
||||
t.Logf("Generated backward cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
},
|
||||
// No cursor set
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no cursor provided") {
|
||||
t.Errorf("Expected 'no cursor provided' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{},
|
||||
CursorForward: "123",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no sort columns") {
|
||||
t.Errorf("Expected 'no sort columns' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "priority", Direction: "DESC"},
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
CursorForward: "789",
|
||||
}
|
||||
|
||||
tableName := "tasks"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify filter contains priority column
|
||||
if !strings.Contains(filter, "priority") {
|
||||
t.Errorf("Filter should reference priority column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter contains created_at column
|
||||
if !strings.Contains(filter, "created_at") {
|
||||
t.Errorf("Filter should reference created_at column, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated multi-column cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "name", Direction: "ASC"},
|
||||
},
|
||||
CursorForward: "100",
|
||||
}
|
||||
|
||||
tableName := "public.users"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options common.RequestOptions
|
||||
expectedID string
|
||||
expectedDirection CursorDirection
|
||||
}{
|
||||
{
|
||||
name: "Forward cursor only",
|
||||
options: common.RequestOptions{
|
||||
CursorForward: "123",
|
||||
},
|
||||
expectedID: "123",
|
||||
expectedDirection: CursorForward,
|
||||
},
|
||||
{
|
||||
name: "Backward cursor only",
|
||||
options: common.RequestOptions{
|
||||
CursorBackward: "456",
|
||||
},
|
||||
expectedID: "456",
|
||||
expectedDirection: CursorBackward,
|
||||
},
|
||||
{
|
||||
name: "Both cursors - forward takes precedence",
|
||||
options: common.RequestOptions{
|
||||
CursorForward: "123",
|
||||
CursorBackward: "456",
|
||||
},
|
||||
expectedID: "123",
|
||||
expectedDirection: CursorForward,
|
||||
},
|
||||
{
|
||||
name: "No cursors",
|
||||
options: common.RequestOptions{},
|
||||
expectedID: "",
|
||||
expectedDirection: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, direction := getActiveCursor(tt.options)
|
||||
|
||||
if id != tt.expectedID {
|
||||
t.Errorf("Expected cursor ID %q, got %q", tt.expectedID, id)
|
||||
}
|
||||
|
||||
if direction != tt.expectedDirection {
|
||||
t.Errorf("Expected direction %d, got %d", tt.expectedDirection, direction)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveColumn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
prefix string
|
||||
tableName string
|
||||
modelColumns []string
|
||||
wantCursor string
|
||||
wantTarget string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Simple column",
|
||||
field: "id",
|
||||
prefix: "",
|
||||
tableName: "users",
|
||||
modelColumns: []string{"id", "name", "email"},
|
||||
wantCursor: "cursor_select.id",
|
||||
wantTarget: "users.id",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Column with case insensitive match",
|
||||
field: "NAME",
|
||||
prefix: "",
|
||||
tableName: "users",
|
||||
modelColumns: []string{"id", "name", "email"},
|
||||
wantCursor: "cursor_select.NAME",
|
||||
wantTarget: "users.NAME",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid column",
|
||||
field: "invalid_field",
|
||||
prefix: "",
|
||||
tableName: "users",
|
||||
modelColumns: []string{"id", "name", "email"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "JSON field",
|
||||
field: "metadata->>'key'",
|
||||
prefix: "",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "metadata"},
|
||||
wantCursor: "cursor_select.metadata->>'key'",
|
||||
wantTarget: "posts.metadata->>'key'",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Joined column (not supported)",
|
||||
field: "name",
|
||||
prefix: "user",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "title"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cursor != tt.wantCursor {
|
||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||
}
|
||||
|
||||
if target != tt.wantTarget {
|
||||
t.Errorf("Expected target %q, got %q", tt.wantTarget, target)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > tasks.priority",
|
||||
"cursor_select.created_at > tasks.created_at",
|
||||
"cursor_select.id < tasks.id",
|
||||
}
|
||||
|
||||
result := buildPriorityChain(clauses)
|
||||
|
||||
// Should build OR-AND chain for cursor comparison
|
||||
if !strings.Contains(result, "OR") {
|
||||
t.Error("Priority chain should contain OR operators")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "AND") {
|
||||
t.Error("Priority chain should contain AND operators for composite conditions")
|
||||
}
|
||||
|
||||
// First clause should appear standalone
|
||||
if !strings.Contains(result, clauses[0]) {
|
||||
t.Errorf("Priority chain should contain first clause: %s", clauses[0])
|
||||
}
|
||||
|
||||
t.Logf("Built priority chain: %s", result)
|
||||
}
|
||||
|
||||
func TestCursorFilter_SQL_Safety(t *testing.T) {
|
||||
// Test that cursor filter doesn't allow SQL injection
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
},
|
||||
CursorForward: "123; DROP TABLE users; --",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// The cursor ID is inserted directly into the query
|
||||
// This should be sanitized by the sanitizeWhereClause function in the handler
|
||||
// For now, just verify it generates a filter
|
||||
if filter == "" {
|
||||
t.Error("Expected non-empty cursor filter even with special characters")
|
||||
}
|
||||
|
||||
t.Logf("Generated filter with special chars in cursor: %s", filter)
|
||||
}
|
||||
@@ -282,17 +282,61 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Apply cursor-based pagination
|
||||
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||
logger.Debug("Applying cursor pagination")
|
||||
|
||||
// Get primary key name
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Extract model columns for validation
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Get cursor filter SQL
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
logger.Error("Error building cursor filter: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply cursor filter to query
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get total count before pagination
|
||||
var total int
|
||||
|
||||
// Try to get from cache first
|
||||
cacheKeyHash := cache.BuildQueryCacheKey(
|
||||
tableName,
|
||||
options.Filters,
|
||||
options.Sort,
|
||||
"", // No custom SQL WHERE in resolvespec
|
||||
"", // No custom SQL OR in resolvespec
|
||||
)
|
||||
// Use extended cache key if cursors are present
|
||||
var cacheKeyHash string
|
||||
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||
cacheKeyHash = cache.BuildExtendedQueryCacheKey(
|
||||
tableName,
|
||||
options.Filters,
|
||||
options.Sort,
|
||||
"", // No custom SQL WHERE in resolvespec
|
||||
"", // No custom SQL OR in resolvespec
|
||||
nil, // No expand options in resolvespec
|
||||
false, // distinct not used here
|
||||
options.CursorForward,
|
||||
options.CursorBackward,
|
||||
)
|
||||
} else {
|
||||
cacheKeyHash = cache.BuildQueryCacheKey(
|
||||
tableName,
|
||||
options.Filters,
|
||||
options.Sort,
|
||||
"", // No custom SQL WHERE in resolvespec
|
||||
"", // No custom SQL OR in resolvespec
|
||||
)
|
||||
}
|
||||
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
|
||||
|
||||
// Try to retrieve from cache
|
||||
|
||||
Reference in New Issue
Block a user