Compare commits

...

5 Commits

Author SHA1 Message Date
Hein
163593901f Huge preload chains causing errors, workaround to do seperate selects.
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-21 17:09:11 +02:00
Hein
1261960e97 Ability to handle multiple x-custom- headers
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-21 12:15:07 +02:00
Hein
76bbf33db2 Fixed SingleRecordAsObject true when handleRead with no id 2025-11-21 11:49:08 +02:00
Hein
02c9b96b0c Better SanitizeWhereClause 2025-11-21 11:42:01 +02:00
Hein
9a3564f05f SanitizeWhereClause with tablename on handlers. 2025-11-21 11:00:44 +02:00
7 changed files with 630 additions and 24 deletions

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"strings" "strings"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -105,6 +106,14 @@ type BunSelectQuery struct {
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries
}
// deferredPreload represents a preload that will be executed as a separate query
// to avoid PostgreSQL identifier length limits
type deferredPreload struct {
relation string
apply []func(common.SelectQuery) common.SelectQuery
} }
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -233,7 +242,92 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
return b return b
} }
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
// // when combined with typical column names
// func shortenAliasForPostgres(relationPath string) (string, bool) {
// // Convert relation path to the alias format Bun uses: dots become double underscores
// // Also convert to lowercase and use snake_case as Bun does
// parts := strings.Split(relationPath, ".")
// alias := strings.ToLower(strings.Join(parts, "__"))
// // PostgreSQL truncates identifiers to 63 chars
// // If the alias + typical column name would exceed this, we need to shorten
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
// const maxAliasLength = 30
// if len(alias) > maxAliasLength {
// // Create a shortened alias using a hash of the original
// hash := md5.Sum([]byte(alias))
// hashStr := hex.EncodeToString(hash[:])[:8]
// // Keep first few chars of original for readability + hash
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
// if prefixLen > len(alias) {
// prefixLen = len(alias)
// }
// shortened := alias[:prefixLen] + "_" + hashStr
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
// alias, len(alias), shortened, len(shortened))
// return shortened, true
// }
// return alias, false
// }
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
// // Bun creates aliases like: relationChain__columnName
// func estimateColumnAliasLength(relationPath string, columnName string) int {
// relationParts := strings.Split(relationPath, ".")
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// // Bun adds "__" between alias and column name
// return len(aliasChain) + 2 + len(columnName)
// }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Check if this relation chain would create problematic long aliases
relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// PostgreSQL's identifier limit is 63 characters
const postgresIdentifierLimit = 63
const safeAliasLimit = 35 // Leave room for column names
// If the alias chain is too long, defer this preload to be executed as a separate query
if len(aliasChain) > safeAliasLimit {
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
// This avoids the long concatenated alias
if len(relationParts) > 1 {
// Load first level normally: "Parent"
firstLevel := relationParts[0]
remainingPath := strings.Join(relationParts[1:], ".")
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
firstLevel, remainingPath)
// Apply the first level preload normally
b.query = b.query.Relation(firstLevel)
// Store the remaining nested preload to be executed after the main query
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
relation: relation,
apply: apply,
})
return b
}
// Single level but still too long - just warn and continue
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
"Consider renaming the field to avoid potential issues.",
relation, len(aliasChain))
}
// Normal preload handling
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -309,7 +403,23 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
if dest == nil { if dest == nil {
return fmt.Errorf("destination cannot be nil") return fmt.Errorf("destination cannot be nil")
} }
return b.query.Scan(ctx, dest)
// Execute the main query first
err = b.query.Scan(ctx, dest)
if err != nil {
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
err = b.executeDeferredPreloads(ctx, dest)
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
} }
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
@@ -322,7 +432,132 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
return fmt.Errorf("model is nil") return fmt.Errorf("model is nil")
} }
return b.query.Scan(ctx) // Execute the main query first
err = b.query.Scan(ctx)
if err != nil {
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
model := b.query.GetModel()
err = b.executeDeferredPreloads(ctx, model.Value())
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
}
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
if len(b.deferredPreloads) == 0 {
return nil
}
for _, dp := range b.deferredPreloads {
err := b.executeSingleDeferredPreload(ctx, dest, dp)
if err != nil {
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
}
}
return nil
}
// executeSingleDeferredPreload executes a single deferred preload
// For a relation like "Parent.Child", it:
// 1. Finds all loaded Parent records in dest
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
relationParts := strings.Split(dp.relation, ".")
if len(relationParts) < 2 {
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
}
// The parent relation that was already loaded
parentRelation := relationParts[0]
// The child relation we need to load
childRelation := strings.Join(relationParts[1:], ".")
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
// Use reflection to access the parent relation field(s) in the loaded records
// Then load the child relation for those parent records
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Handle both slice and single record
if destValue.Kind() == reflect.Slice {
// Iterate through each record in the slice
for i := 0; i < destValue.Len(); i++ {
record := destValue.Index(i)
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
// Continue with other records
}
}
} else {
// Single record
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
}
}
return nil
}
// loadChildRelationForRecord loads a child relation for a single parent record
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
// Ensure we're working with the actual struct value, not a pointer
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Get the parent relation field
parentField := record.FieldByName(parentRelation)
if !parentField.IsValid() {
// Parent relation field doesn't exist
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
return nil
}
// Check if the parent field is nil (for pointer fields)
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
// Parent relation not loaded or nil, skip
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
return nil
}
// Get the interface value to pass to Bun
parentValue := parentField.Interface()
// Load the child relation on the parent record
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
return b.db.NewSelect().
Model(parentValue).
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
// Apply any custom query modifications
if len(apply) > 0 {
wrapper := &BunSelectQuery{query: sq, db: b.db}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalBun, ok := current.(*BunSelectQuery); ok {
return finalBun.query
}
}
return sq
}).
Scan(ctx)
} }
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {

View File

@@ -5,6 +5,8 @@ import (
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains // ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
@@ -135,6 +137,15 @@ func SanitizeWhereClause(where string, tableName string) string {
where = strings.TrimSpace(where) where = strings.TrimSpace(where)
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
if tableName != "" {
validColumns = getValidColumnsForTable(tableName)
}
// Split by AND to handle multiple conditions // Split by AND to handle multiple conditions
conditions := splitByAND(where) conditions := splitByAND(where)
@@ -146,22 +157,32 @@ func SanitizeWhereClause(where string, tableName string) string {
continue continue
} }
// Strip parentheses from the condition before checking
condToCheck := stripOuterParentheses(cond)
// Skip trivial conditions that always evaluate to true // Skip trivial conditions that always evaluate to true
if IsTrivialCondition(cond) { if IsTrivialCondition(condToCheck) {
logger.Debug("Removing trivial condition: '%s'", cond) logger.Debug("Removing trivial condition: '%s'", cond)
continue continue
} }
// If tableName is provided and the condition doesn't already have a table prefix, // If tableName is provided and the condition doesn't already have a table prefix,
// attempt to add it // attempt to add it
if tableName != "" && !hasTablePrefix(cond) { if tableName != "" && !hasTablePrefix(condToCheck) {
// Check if this is a SQL expression/literal that shouldn't be prefixed // Check if this is a SQL expression/literal that shouldn't be prefixed
if !IsSQLExpression(strings.ToLower(cond)) { if !IsSQLExpression(strings.ToLower(condToCheck)) {
// Extract the column name and prefix it // Extract the column name and prefix it
columnName := ExtractColumnName(cond) columnName := ExtractColumnName(condToCheck)
if columnName != "" { if columnName != "" {
// Only prefix if this is a valid column in the model
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace in the original condition (without stripped parens)
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1) cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
logger.Debug("Prefixed column in condition: '%s'", cond) logger.Debug("Prefixed column in condition: '%s'", cond)
} else {
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
}
} }
} }
} }
@@ -182,6 +203,43 @@ func SanitizeWhereClause(where string, tableName string) string {
return result return result
} }
// stripOuterParentheses removes matching outer parentheses from a string
// It handles nested parentheses correctly
func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive) // splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions // This is a simple split that doesn't handle nested parentheses or complex expressions
func splitByAND(where string) []string { func splitByAND(where string) []string {
@@ -245,3 +303,38 @@ func IsSQLKeyword(word string) bool {
} }
return false return false
} }
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
// Returns a map of column names for fast lookup, or nil if the model is not found
func getValidColumnsForTable(tableName string) map[string]bool {
// Try to get the model from the registry
model, err := modelregistry.GetModelByName(tableName)
if err != nil {
// Model not found, return nil to indicate we should use fallback behavior
return nil
}
// Get SQL columns from the model
columns := reflection.GetSQLModelColumns(model)
if len(columns) == 0 {
// No columns found, return nil
return nil
}
// Build a map for fast lookup
columnMap := make(map[string]bool, len(columns))
for _, col := range columns {
columnMap[strings.ToLower(col)] = true
}
return columnMap
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {
if validColumns == nil {
return true // No model info, assume valid
}
return validColumns[strings.ToLower(columnName)]
}

View File

@@ -0,0 +1,224 @@
package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
func TestSanitizeWhereClause(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "trivial conditions in parentheses",
where: "(true AND true AND true)",
tableName: "mastertask",
expected: "",
},
{
name: "trivial conditions without parentheses",
where: "true AND true AND true",
tableName: "mastertask",
expected: "",
},
{
name: "single trivial condition",
where: "true",
tableName: "mastertask",
expected: "",
},
{
name: "valid condition with parentheses",
where: "(status = 'active')",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "mixed trivial and valid conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition already with table prefix",
where: "users.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple valid conditions",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "no table name provided",
where: "status = 'active'",
tableName: "",
expected: "status = 'active'",
},
{
name: "empty where clause",
where: "",
tableName: "users",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestStripOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "single level parentheses",
input: "(true)",
expected: "true",
},
{
name: "multiple levels",
input: "((true))",
expected: "true",
},
{
name: "no parentheses",
input: "true",
expected: "true",
},
{
name: "mismatched parentheses",
input: "(true",
expected: "(true",
},
{
name: "complex expression",
input: "(a AND b)",
expected: "a AND b",
},
{
name: "nested but not outer",
input: "(a AND (b OR c)) AND d",
expected: "(a AND (b OR c)) AND d",
},
{
name: "with spaces",
input: " ( true ) ",
expected: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestIsTrivialCondition(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"true", "true", true},
{"true with spaces", " true ", true},
{"TRUE uppercase", "TRUE", true},
{"1=1", "1=1", true},
{"1 = 1", "1 = 1", true},
{"true = true", "true = true", true},
{"valid condition", "status = 'active'", false},
{"false", "false", false},
{"column name", "is_active", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsTrivialCondition(tt.input)
if result != tt.expected {
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests
type MasterTask struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Status string `bun:"status"`
UserID int `bun:"user_id"`
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
if err != nil {
// Model might already be registered, ignore error
t.Logf("Model registration returned: %v", err)
}
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "valid column gets prefixed",
where: "status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "multiple valid columns get prefixed",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
},
{
name: "invalid column does not get prefixed",
where: "invalid_column = 'value'",
tableName: "mastertask",
expected: "invalid_column = 'value'",
},
{
name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "parentheses with valid column",
where: "(status = 'active')",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -17,3 +17,33 @@ func Len(v any) int {
return 0 return 0
} }
} }
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
// dots, it returns everything after the last dot up to the first delimiter.
func ExtractTableNameOnly(fullName string) string {
// First, split by dot to remove schema prefix if present
lastDotIndex := -1
for i, char := range fullName {
if char == '.' {
lastDotIndex = i
}
}
// Start from after the last dot (or from beginning if no dot)
startIndex := 0
if lastDotIndex != -1 {
startIndex = lastDotIndex + 1
}
// Now find the end (first delimiter after the table name)
for i := startIndex; i < len(fullName); i++ {
char := rune(fullName[i])
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
return fullName[startIndex:i]
}
}
return fullName[startIndex:]
}

View File

@@ -199,7 +199,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection // Apply column selection
if len(options.Columns) > 0 { if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns) logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...) for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
} }
if len(options.ComputedColumns) > 0 { if len(options.ComputedColumns) > 0 {
@@ -1209,7 +1211,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
} }
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation) sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
if len(sanitizedWhere) > 0 { if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere) sq = sq.Where(sanitizedWhere)
} }

View File

@@ -213,6 +213,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
model := GetModel(ctx) model := GetModel(ctx)
if id == "" {
options.SingleRecordAsObject = false
}
// Execute BeforeRead hooks // Execute BeforeRead hooks
hookCtx := &HookContext{ hookCtx := &HookContext{
Context: ctx, Context: ctx,
@@ -299,7 +303,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection // Apply column selection
if len(options.Columns) > 0 { if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns) logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...) for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
} }
// Apply expand (Just expand to Preload for now) // Apply expand (Just expand to Preload for now)
@@ -392,7 +399,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
if options.CustomSQLWhere != "" { if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables // Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "") sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
if sanitizedWhere != "" { if sanitizedWhere != "" {
query = query.Where(sanitizedWhere) query = query.Where(sanitizedWhere)
} }
@@ -402,7 +409,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
if options.CustomSQLOr != "" { if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables // Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "") sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
if sanitizedOr != "" { if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr) query = query.WhereOr(sanitizedOr)
} }
@@ -481,7 +488,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query // Apply cursor filter to query
if cursorFilter != "" { if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter) logger.Debug("Applying cursor filter: %s", cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "") sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
if sanitizedCursor != "" { if sanitizedCursor != "" {
query = query.Where(sanitizedCursor) query = query.Where(sanitizedCursor)
} }
@@ -655,7 +662,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply WHERE clause // Apply WHERE clause
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation) sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
if len(sanitizedWhere) > 0 { if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere) sq = sq.Where(sanitizedWhere)
} }

View File

@@ -162,9 +162,17 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
case strings.HasPrefix(key, "x-searchcols"): case strings.HasPrefix(key, "x-searchcols"):
options.SearchColumns = h.parseCommaSeparated(decodedValue) options.SearchColumns = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-custom-sql-w"): case strings.HasPrefix(key, "x-custom-sql-w"):
if options.CustomSQLWhere != "" {
options.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", options.CustomSQLWhere, decodedValue)
} else {
options.CustomSQLWhere = decodedValue options.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"): case strings.HasPrefix(key, "x-custom-sql-or"):
if options.CustomSQLOr != "" {
options.CustomSQLOr = fmt.Sprintf("%s OR (%s)", options.CustomSQLOr, decodedValue)
} else {
options.CustomSQLOr = decodedValue options.CustomSQLOr = decodedValue
}
// Joins & Relations // Joins & Relations
case strings.HasPrefix(key, "x-preload"): case strings.HasPrefix(key, "x-preload"):
@@ -226,6 +234,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
case strings.HasPrefix(key, "x-cql-sel-"): case strings.HasPrefix(key, "x-cql-sel-"):
colName := strings.TrimPrefix(key, "x-cql-sel-") colName := strings.TrimPrefix(key, "x-cql-sel-")
options.ComputedQL[colName] = decodedValue options.ComputedQL[colName] = decodedValue
case strings.HasPrefix(key, "x-distinct"): case strings.HasPrefix(key, "x-distinct"):
options.Distinct = strings.EqualFold(decodedValue, "true") options.Distinct = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-skipcount"): case strings.HasPrefix(key, "x-skipcount"):
@@ -267,6 +276,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
h.resolveRelationNamesInOptions(&options, model) h.resolveRelationNamesInOptions(&options, model)
} }
// Always sort according to the primary key if no sorting is specified
if len(options.Sort) == 0 {
pkName := reflection.GetPrimaryKeyName(model)
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
return options return options
} }
@@ -777,7 +792,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
field := modelType.Field(i) field := modelType.Field(i)
if field.Name == nameOrTable { if field.Name == nameOrTable {
// It's already a field name // It's already a field name
logger.Debug("Input '%s' is a field name", nameOrTable) // logger.Debug("Input '%s' is a field name", nameOrTable)
return nameOrTable return nameOrTable
} }
} }