From e1abd5ebc13bb94356fb423614c23cb667dbe4d8 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 10 Dec 2025 08:36:24 +0200 Subject: [PATCH] Enhanced the SanitizeWhereClause function --- pkg/common/sql_helpers.go | 105 ++++++++++++--- pkg/common/sql_helpers_test.go | 228 ++++++++++++++++++++++++++++++--- pkg/resolvespec/handler.go | 6 +- pkg/restheadspec/handler.go | 20 +-- 4 files changed, 310 insertions(+), 49 deletions(-) diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 31e8638..1d2ac6c 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -78,17 +78,22 @@ func IsTrivialCondition(cond string) bool { return false } -// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns +// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes // This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL // // Parameters: // - where: The WHERE clause string to sanitize -// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing) +// - tableName: The correct table/relation name to use when fixing incorrect prefixes +// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes // // Returns: -// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed +// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed // - An empty string if all conditions were trivial or the input was empty -func SanitizeWhereClause(where string, tableName string) string { +// +// Note: This function will NOT add prefixes to unprefixed columns. It will only fix +// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the +// prefix matches a preloaded relation name, in which case it's left unchanged. +func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string { if where == "" { return "" } @@ -104,6 +109,22 @@ func SanitizeWhereClause(where string, tableName string) string { validColumns = getValidColumnsForTable(tableName) } + // Build a set of allowed table prefixes (main table + preloaded relations) + allowedPrefixes := make(map[string]bool) + if tableName != "" { + allowedPrefixes[tableName] = true + } + + // Add preload relation names as allowed prefixes + if len(options) > 0 && options[0] != nil { + for pi := range options[0].Preload { + if options[0].Preload[pi].Relation != "" { + allowedPrefixes[options[0].Preload[pi].Relation] = true + logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) + } + } + } + // Split by AND to handle multiple conditions conditions := splitByAND(where) @@ -124,22 +145,23 @@ func SanitizeWhereClause(where string, tableName string) string { continue } - // If tableName is provided and the condition doesn't already have a table prefix, - // attempt to add it - if tableName != "" && !hasTablePrefix(condToCheck) { - // Check if this is a SQL expression/literal that shouldn't be prefixed - if !IsSQLExpression(strings.ToLower(condToCheck)) { - // Extract the column name and prefix it - columnName := ExtractColumnName(condToCheck) - if columnName != "" { - // Only prefix if this is a valid column in the model - // If we don't have model info (validColumns is nil), prefix anyway for backward compatibility + // If tableName is provided and the condition HAS a table prefix, check if it's correct + if tableName != "" && hasTablePrefix(condToCheck) { + // Extract the current prefix and column name + currentPrefix, columnName := extractTableAndColumn(condToCheck) + + if currentPrefix != "" && columnName != "" { + // Check if the prefix is allowed (main table or preload relation) + if !allowedPrefixes[currentPrefix] { + // Prefix is not in the allowed list - only fix if it's a valid column in the main table if validColumns == nil || isValidColumn(columnName, validColumns) { - // Replace in the original condition (without stripped parens) - cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1) - logger.Debug("Prefixed column in condition: '%s'", cond) + // Replace the incorrect prefix with the correct main table name + oldRef := currentPrefix + "." + columnName + newRef := tableName + "." + columnName + cond = strings.Replace(cond, oldRef, newRef, 1) + logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef) } else { - logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName) + logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName) } } } @@ -288,6 +310,53 @@ func getValidColumnsForTable(tableName string) map[string]bool { return columnMap } +// extractTableAndColumn extracts the table prefix and column name from a qualified reference +// For example: "users.status = 'active'" returns ("users", "status") +// Returns empty strings if no table prefix is found +func extractTableAndColumn(cond string) (table string, column string) { + // Common SQL operators to find the column reference + operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "} + + var columnRef string + + // Find the column reference (left side of the operator) + for _, op := range operators { + if idx := strings.Index(cond, op); idx > 0 { + columnRef = strings.TrimSpace(cond[:idx]) + break + } + } + + // If no operator found, the whole condition might be the column reference + if columnRef == "" { + parts := strings.Fields(cond) + if len(parts) > 0 { + columnRef = parts[0] + } + } + + if columnRef == "" { + return "", "" + } + + // Remove any quotes + columnRef = strings.Trim(columnRef, "`\"'") + + // Check if it contains a dot (qualified reference) + if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 { + table = columnRef[:dotIdx] + column = columnRef[dotIdx+1:] + + // Remove quotes from table and column if present + table = strings.Trim(table, "`\"'") + column = strings.Trim(column, "`\"'") + + return table, column + } + + return "", "" +} + // isValidColumn checks if a column name exists in the valid columns map // Handles case-insensitive comparison func isValidColumn(columnName string, validColumns map[string]bool) bool { diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index 2b0d5a0..5f328a5 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -32,29 +32,41 @@ func TestSanitizeWhereClause(t *testing.T) { expected: "", }, { - name: "valid condition with parentheses", + name: "valid condition with parentheses - no prefix added", where: "(status = 'active')", tableName: "users", - expected: "users.status = 'active'", + expected: "status = 'active'", }, { - name: "mixed trivial and valid conditions", + name: "mixed trivial and valid conditions - no prefix added", where: "true AND status = 'active' AND 1=1", tableName: "users", - expected: "users.status = 'active'", + expected: "status = 'active'", }, { - name: "condition already with table prefix", + name: "condition with correct table prefix - unchanged", where: "users.status = 'active'", tableName: "users", expected: "users.status = 'active'", }, { - name: "multiple valid conditions", - where: "status = 'active' AND age > 18", + name: "condition with incorrect table prefix - fixed", + where: "wrong_table.status = 'active'", + tableName: "users", + expected: "users.status = 'active'", + }, + { + name: "multiple conditions with incorrect prefix - fixed", + where: "wrong_table.status = 'active' AND wrong_table.age > 18", tableName: "users", expected: "users.status = 'active' AND users.age > 18", }, + { + name: "multiple valid conditions without prefix - no prefix added", + where: "status = 'active' AND age > 18", + tableName: "users", + expected: "status = 'active' AND age > 18", + }, { name: "no table name provided", where: "status = 'active'", @@ -67,6 +79,12 @@ func TestSanitizeWhereClause(t *testing.T) { tableName: "users", expected: "", }, + { + name: "mixed correct and incorrect prefixes", + where: "users.status = 'active' AND wrong_table.age > 18", + tableName: "users", + expected: "users.status = 'active' AND users.age > 18", + }, } for _, tt := range tests { @@ -159,6 +177,158 @@ func TestIsTrivialCondition(t *testing.T) { } } +func TestExtractTableAndColumn(t *testing.T) { + tests := []struct { + name string + input string + expectedTable string + expectedCol string + }{ + { + name: "qualified column with equals", + input: "users.status = 'active'", + expectedTable: "users", + expectedCol: "status", + }, + { + name: "qualified column with greater than", + input: "users.age > 18", + expectedTable: "users", + expectedCol: "age", + }, + { + name: "qualified column with LIKE", + input: "users.name LIKE '%john%'", + expectedTable: "users", + expectedCol: "name", + }, + { + name: "qualified column with IN", + input: "users.status IN ('active', 'pending')", + expectedTable: "users", + expectedCol: "status", + }, + { + name: "unqualified column", + input: "status = 'active'", + expectedTable: "", + expectedCol: "", + }, + { + name: "qualified with backticks", + input: "`users`.`status` = 'active'", + expectedTable: "users", + expectedCol: "status", + }, + { + name: "schema.table.column reference", + input: "public.users.status = 'active'", + expectedTable: "public.users", + expectedCol: "status", + }, + { + name: "empty string", + input: "", + expectedTable: "", + expectedCol: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + table, col := extractTableAndColumn(tt.input) + if table != tt.expectedTable || col != tt.expectedCol { + t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)", + tt.input, table, col, tt.expectedTable, tt.expectedCol) + } + }) + } +} + +func TestSanitizeWhereClauseWithPreloads(t *testing.T) { + tests := []struct { + name string + where string + tableName string + options *RequestOptions + expected string + }{ + { + name: "preload relation prefix is preserved", + where: "Department.name = 'Engineering'", + tableName: "users", + options: &RequestOptions{ + Preload: []PreloadOption{ + {Relation: "Department"}, + }, + }, + expected: "Department.name = 'Engineering'", + }, + { + name: "multiple preload relations - all preserved", + where: "Department.name = 'Engineering' AND Manager.status = 'active'", + tableName: "users", + options: &RequestOptions{ + Preload: []PreloadOption{ + {Relation: "Department"}, + {Relation: "Manager"}, + }, + }, + expected: "Department.name = 'Engineering' AND Manager.status = 'active'", + }, + { + name: "mix of main table and preload relation", + where: "users.status = 'active' AND Department.name = 'Engineering'", + tableName: "users", + options: &RequestOptions{ + Preload: []PreloadOption{ + {Relation: "Department"}, + }, + }, + expected: "users.status = 'active' AND Department.name = 'Engineering'", + }, + { + name: "incorrect prefix fixed when not a preload relation", + where: "wrong_table.status = 'active' AND Department.name = 'Engineering'", + tableName: "users", + options: &RequestOptions{ + Preload: []PreloadOption{ + {Relation: "Department"}, + }, + }, + expected: "users.status = 'active' AND Department.name = 'Engineering'", + }, + { + name: "no options provided - works as before", + where: "wrong_table.status = 'active'", + tableName: "users", + options: nil, + expected: "users.status = 'active'", + }, + { + name: "empty preload list - works as before", + where: "wrong_table.status = 'active'", + tableName: "users", + options: &RequestOptions{Preload: []PreloadOption{}}, + expected: "users.status = 'active'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result string + if tt.options != nil { + result = SanitizeWhereClause(tt.where, tt.tableName, tt.options) + } else { + result = SanitizeWhereClause(tt.where, tt.tableName) + } + if result != tt.expected { + t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected) + } + }) + } +} + // Test model for model-aware sanitization tests type MasterTask struct { ID int `bun:"id,pk"` @@ -182,34 +352,52 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) { expected string }{ { - name: "valid column gets prefixed", + name: "valid column without prefix - no prefix added", where: "status = 'active'", tableName: "mastertask", + expected: "status = 'active'", + }, + { + name: "multiple valid columns without prefix - no prefix added", + where: "status = 'active' AND user_id = 123", + tableName: "mastertask", + expected: "status = 'active' AND user_id = 123", + }, + { + name: "incorrect table prefix on valid column - fixed", + where: "wrong_table.status = 'active'", + tableName: "mastertask", expected: "mastertask.status = 'active'", }, { - name: "multiple valid columns get prefixed", - where: "status = 'active' AND user_id = 123", + name: "incorrect prefix on invalid column - not fixed", + where: "wrong_table.invalid_column = 'value'", tableName: "mastertask", - expected: "mastertask.status = 'active' AND mastertask.user_id = 123", - }, - { - name: "invalid column does not get prefixed", - where: "invalid_column = 'value'", - tableName: "mastertask", - expected: "invalid_column = 'value'", + expected: "wrong_table.invalid_column = 'value'", }, { name: "mix of valid and trivial conditions", where: "true AND status = 'active' AND 1=1", tableName: "mastertask", + expected: "status = 'active'", + }, + { + name: "parentheses with valid column - no prefix added", + where: "(status = 'active')", + tableName: "mastertask", + expected: "status = 'active'", + }, + { + name: "correct prefix - unchanged", + where: "mastertask.status = 'active'", + tableName: "mastertask", expected: "mastertask.status = 'active'", }, { - name: "parentheses with valid column", - where: "(status = 'active')", + name: "multiple conditions with mixed prefixes", + where: "mastertask.status = 'active' AND wrong_table.user_id = 123", tableName: "mastertask", - expected: "mastertask.status = 'active'", + expected: "mastertask.status = 'active' AND mastertask.user_id = 123", }, } diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index d410851..84512b6 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -316,7 +316,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply cursor filter to query if cursorFilter != "" { logger.Debug("Applying cursor filter: %s", cursorFilter) - sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName)) + sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options) if sanitizedCursor != "" { query = query.Where(sanitizedCursor) } @@ -1351,7 +1351,9 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre } if len(preload.Where) > 0 { - sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) + // Build RequestOptions with all preloads to allow references to sibling relations + preloadOpts := &common.RequestOptions{Preload: preloads} + sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index df239fb..4295189 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -450,7 +450,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } // Apply the preload with recursive support - query = h.applyPreloadWithRecursion(query, preload, model, 0) + query = h.applyPreloadWithRecursion(query, preload, options.Preload, model, 0) } // Apply DISTINCT if requested @@ -480,8 +480,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply custom SQL WHERE clause (AND condition) if options.CustomSQLWhere != "" { logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) - // Sanitize without auto-prefixing since custom SQL may reference multiple tables - sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName)) + // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables + sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) if sanitizedWhere != "" { query = query.Where(sanitizedWhere) } @@ -490,8 +490,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply custom SQL WHERE clause (OR condition) if options.CustomSQLOr != "" { logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) - // Sanitize without auto-prefixing since custom SQL may reference multiple tables - sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName)) + // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables + sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) if sanitizedOr != "" { query = query.WhereOr(sanitizedOr) } @@ -625,7 +625,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply cursor filter to query if cursorFilter != "" { logger.Debug("Applying cursor filter: %s", cursorFilter) - sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName)) + sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) if sanitizedCursor != "" { query = query.Where(sanitizedCursor) } @@ -703,7 +703,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } // applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading -func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery { +func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, allPreloads []common.PreloadOption, model interface{}, depth int) common.SelectQuery { // Log relationship keys if they're specified (from XFiles) if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" { logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s", @@ -799,7 +799,9 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply WHERE clause if len(preload.Where) > 0 { - sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) + // Build RequestOptions with all preloads to allow references to sibling relations + preloadOpts := &common.RequestOptions{Preload: allPreloads} + sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } @@ -832,7 +834,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co recursivePreload.Relation = preload.Relation + "." + lastRelationName // Recursively apply preload until we reach depth 5 - query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1) + query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1) } return query