From ed67caf055dad94c5412f786a36d0c2354beab14 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 23 Dec 2025 14:17:02 +0200 Subject: [PATCH] fix: reasheadspec customsql calls AddTablePrefixToColumns --- pkg/common/sql_helpers.go | 158 ++++++++++++++++++++++++++++++++---- pkg/restheadspec/handler.go | 9 +- 2 files changed, 150 insertions(+), 17 deletions(-) diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 5b14ce5..2036dfb 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -208,21 +208,9 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti } } } - } else if tableName != "" && !hasTablePrefix(condToCheck) { - // If tableName is provided and the condition DOESN'T have a table prefix, - // qualify unambiguous column references to prevent "ambiguous column" errors - // when there are multiple joins on the same table (e.g., recursive preloads) - columnName := extractUnqualifiedColumnName(condToCheck) - if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) { - // Qualify the column with the table name - // Be careful to only replace the column name, not other occurrences of the string - oldRef := columnName - newRef := tableName + "." + columnName - // Use word boundary matching to avoid replacing partial matches - cond = qualifyColumnInCondition(cond, oldRef, newRef) - logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName) - } } + // Note: We no longer add prefixes to unqualified columns here. + // Use AddTablePrefixToColumns() separately if you need to add prefixes. validConditions = append(validConditions, cond) } @@ -633,3 +621,145 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool { } return validColumns[strings.ToLower(columnName)] } + +// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause. +// This function only prefixes simple column references and skips: +// - Columns already having a table prefix (containing a dot) +// - Columns inside function calls or expressions (inside parentheses) +// - Columns inside subqueries +// - Columns that don't exist in the table (validation via model registry) +// +// Examples: +// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table) +// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function) +// - "users.status = 'active'" -> unchanged (already has prefix) +// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK) +// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table) +// +// Parameters: +// - where: The WHERE clause to process +// - tableName: The table name to use as prefix +// +// Returns: +// - The WHERE clause with table prefixes added to appropriate and valid columns +func AddTablePrefixToColumns(where string, tableName string) string { + if where == "" || tableName == "" { + return where + } + + where = strings.TrimSpace(where) + + // Get valid columns from the model registry for validation + validColumns := getValidColumnsForTable(tableName) + + // Split by AND to handle multiple conditions (parenthesis-aware) + conditions := splitByAND(where) + prefixedConditions := make([]string, 0, len(conditions)) + + for _, cond := range conditions { + cond = strings.TrimSpace(cond) + if cond == "" { + continue + } + + // Process this condition to add table prefix if appropriate + processedCond := addPrefixToSingleCondition(cond, tableName, validColumns) + prefixedConditions = append(prefixedConditions, processedCond) + } + + if len(prefixedConditions) == 0 { + return "" + } + + return strings.Join(prefixedConditions, " AND ") +} + +// addPrefixToSingleCondition adds table prefix to a single condition if appropriate +// Returns the condition unchanged if: +// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.) +// - The column reference is inside a function call +// - The column already has a table prefix +// - No valid column reference is found +// - The column doesn't exist in the table (when validColumns is provided) +func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string { + // Strip outer grouping parentheses to get to the actual condition + strippedCond := stripOuterParentheses(cond) + + // Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.) + if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) { + logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond) + return cond + } + + // Extract the left side of the comparison (before the operator) + columnRef := extractLeftSideOfComparison(strippedCond) + if columnRef == "" { + return cond + } + + // Skip if it already has a prefix (contains a dot) + if strings.Contains(columnRef, ".") { + logger.Debug("Skipping column '%s' - already has table prefix", columnRef) + return cond + } + + // Skip if it's a function call or expression (contains parentheses) + if strings.Contains(columnRef, "(") { + logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef) + return cond + } + + // Validate that the column exists in the table (if we have column info) + if !isValidColumn(columnRef, validColumns) { + logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName) + return cond + } + + // It's a simple unqualified column reference that exists in the table - add the table prefix + newRef := tableName + "." + columnRef + result := qualifyColumnInCondition(cond, columnRef, newRef) + logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef) + + return result +} + +// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition. +// This is used to identify the column reference that may need a table prefix. +// +// Examples: +// - "status = 'active'" returns "status" +// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')" +// - "priority > 5" returns "priority" +// +// Returns empty string if no operator is found. +func extractLeftSideOfComparison(cond string) string { + operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "} + + // Find the first operator outside of parentheses and quotes + minIdx := -1 + for _, op := range operators { + idx := findOperatorOutsideParentheses(cond, op) + if idx > 0 && (minIdx == -1 || idx < minIdx) { + minIdx = idx + } + } + + if minIdx > 0 { + leftSide := strings.TrimSpace(cond[:minIdx]) + // Remove any surrounding quotes + leftSide = strings.Trim(leftSide, "`\"'") + return leftSide + } + + // No operator found - might be a boolean column + parts := strings.Fields(cond) + if len(parts) > 0 { + columnRef := strings.Trim(parts[0], "`\"'") + // Make sure it's not a SQL keyword + if !IsSQLKeyword(strings.ToLower(columnRef)) { + return columnRef + } + } + + return "" +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 2cd2e5c..f109fd5 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -482,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply custom SQL WHERE clause (AND condition) if options.CustomSQLWhere != "" { logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) - // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables - sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) + // First add table prefixes to unqualified columns (but skip columns inside function calls) + prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName)) + // Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables + sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) if sanitizedWhere != "" { query = query.Where(sanitizedWhere) } @@ -492,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply custom SQL WHERE clause (OR condition) if options.CustomSQLOr != "" { logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) + customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName)) // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables - sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) + sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) if sanitizedOr != "" { query = query.WhereOr(sanitizedOr) }