diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 26d8053..3093d49 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -130,6 +130,9 @@ func validateWhereClauseSecurity(where string) error { // Note: This function will NOT add prefixes to unprefixed columns. It will only fix // incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the // prefix matches a preloaded relation name, in which case it's left unchanged. +// +// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators +// to prevent OR logic from escaping and affecting the entire query incorrectly. func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string { if where == "" { return "" @@ -143,8 +146,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti return "" } - // Strip outer parentheses and re-trim - where = stripOuterParentheses(where) + // Check if the original clause has outer parentheses and contains OR operators + // If so, we need to preserve the outer parentheses to prevent OR logic from escaping + hasOuterParens := false + if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' { + _, hasOuterParens = stripOneMatchingOuterParen(where) + } + + // Strip outer parentheses and re-trim for processing + whereWithoutParens := stripOuterParentheses(where) + shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens) + + // Use the stripped version for processing + where = whereWithoutParens // Get valid columns from the model if tableName is provided var validColumns map[string]bool @@ -229,7 +243,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti result := strings.Join(validConditions, " AND ") - if result != where { + // If the original clause had outer parentheses and contains OR operators, + // restore the outer parentheses to prevent OR logic from escaping + if shouldPreserveParens { + result = "(" + result + ")" + logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result) + } + + if result != where && !shouldPreserveParens { logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result) } @@ -290,6 +311,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) { return strings.TrimSpace(s[1 : len(s)-1]), true } +// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses +// to prevent OR logic from escaping. It checks if the clause already has +// matching outer parentheses and only adds them if they don't exist. +// +// This is particularly important for OR conditions and complex filters where +// the absence of parentheses could cause the logic to escape and affect +// the entire query incorrectly. +// +// Parameters: +// - clause: The SQL clause to check and potentially wrap +// +// Returns: +// - The clause with guaranteed outer parentheses, or empty string if input is empty +func EnsureOuterParentheses(clause string) string { + if clause == "" { + return "" + } + + clause = strings.TrimSpace(clause) + if clause == "" { + return "" + } + + // Check if the clause already has matching outer parentheses + _, hasOuterParens := stripOneMatchingOuterParen(clause) + + // If it already has matching outer parentheses, return as-is + if hasOuterParens { + return clause + } + + // Otherwise, wrap it in parentheses + return "(" + clause + ")" +} + +// containsTopLevelOR checks if a SQL clause contains OR operators at the top level +// (i.e., not inside parentheses or subqueries). This is used to determine if +// outer parentheses should be preserved to prevent OR logic from escaping. +func containsTopLevelOR(clause string) bool { + if clause == "" { + return false + } + + depth := 0 + inSingleQuote := false + inDoubleQuote := false + lowerClause := strings.ToLower(clause) + + for i := 0; i < len(clause); i++ { + ch := clause[i] + + // Track quote state + if ch == '\'' && !inDoubleQuote { + inSingleQuote = !inSingleQuote + continue + } + if ch == '"' && !inSingleQuote { + inDoubleQuote = !inDoubleQuote + continue + } + + // Skip if inside quotes + if inSingleQuote || inDoubleQuote { + continue + } + + // Track parenthesis depth + switch ch { + case '(': + depth++ + case ')': + depth-- + } + + // Only check for OR at depth 0 (not inside parentheses) + if depth == 0 && i+4 <= len(clause) { + // Check for " OR " (case-insensitive) + substring := lowerClause[i : i+4] + if substring == " or " { + return true + } + } + } + + return false +} + // splitByAND splits a WHERE clause by AND operators (case-insensitive) // This is parenthesis-aware and won't split on AND operators inside subqueries func splitByAND(where string) []string { diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index 6f2a4ca..acfd831 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) { } } +func TestEnsureOuterParentheses(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no parentheses", + input: "status = 'active'", + expected: "(status = 'active')", + }, + { + name: "already has outer parentheses", + input: "(status = 'active')", + expected: "(status = 'active')", + }, + { + name: "OR condition without parentheses", + input: "status = 'active' OR status = 'pending'", + expected: "(status = 'active' OR status = 'pending')", + }, + { + name: "OR condition with parentheses", + input: "(status = 'active' OR status = 'pending')", + expected: "(status = 'active' OR status = 'pending')", + }, + { + name: "complex condition with nested parentheses", + input: "(status = 'active' OR status = 'pending') AND (age > 18)", + expected: "((status = 'active' OR status = 'pending') AND (age > 18))", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "whitespace only", + input: " ", + expected: "", + }, + { + name: "mismatched parentheses - adds outer ones", + input: "(status = 'active' OR status = 'pending'", + expected: "((status = 'active' OR status = 'pending')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := EnsureOuterParentheses(tt.input) + if result != tt.expected { + t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestContainsTopLevelOR(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + { + name: "no OR operator", + input: "status = 'active' AND age > 18", + expected: false, + }, + { + name: "top-level OR", + input: "status = 'active' OR status = 'pending'", + expected: true, + }, + { + name: "OR inside parentheses", + input: "age > 18 AND (status = 'active' OR status = 'pending')", + expected: false, + }, + { + name: "OR in subquery", + input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')", + expected: false, + }, + { + name: "OR inside quotes", + input: "comment = 'this OR that'", + expected: false, + }, + { + name: "mixed - top-level OR and nested OR", + input: "name = 'test' OR (status = 'active' OR status = 'pending')", + expected: true, + }, + { + name: "empty string", + input: "", + expected: false, + }, + { + name: "lowercase or", + input: "status = 'active' or status = 'pending'", + expected: true, + }, + { + name: "uppercase OR", + input: "status = 'active' OR status = 'pending'", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsTopLevelOR(tt.input) + if result != tt.expected { + t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) { + tests := []struct { + name string + where string + tableName string + expected string + }{ + { + name: "OR condition with outer parentheses - preserved", + where: "(status = 'active' OR status = 'pending')", + tableName: "users", + expected: "(users.status = 'active' OR users.status = 'pending')", + }, + { + name: "AND condition with outer parentheses - stripped (no OR)", + where: "(status = 'active' AND age > 18)", + tableName: "users", + expected: "users.status = 'active' AND users.age > 18", + }, + { + name: "complex OR with nested conditions", + where: "((status = 'active' OR status = 'pending') AND age > 18)", + tableName: "users", + // Outer parens are stripped, but inner parens with OR are preserved + expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18", + }, + { + name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause", + where: "status = 'active' OR status = 'pending'", + tableName: "users", + expected: "users.status = 'active' OR users.status = 'pending'", + }, + { + name: "simple OR with parentheses - preserved", + where: "(users.status = 'active' OR users.status = 'pending')", + tableName: "users", + // Already has correct prefixes, parentheses preserved + expected: "(users.status = 'active' OR users.status = 'pending')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName) + result := SanitizeWhereClause(prefixedWhere, tt.tableName) + if result != tt.expected { + t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected) + } + }) + } +} + func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) { tests := []struct { name string diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index a8d32d0..33a1cb6 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -318,6 +318,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st if cursorFilter != "" { logger.Debug("Applying cursor filter: %s", cursorFilter) sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor) if sanitizedCursor != "" { query = query.Where(sanitizedCursor) } @@ -1656,6 +1658,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // Build RequestOptions with all preloads to allow references to sibling relations preloadOpts := &common.RequestOptions{Preload: preloads} sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index c176bcb..d5f0aa1 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -463,7 +463,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } // Apply filters - validate and adjust for column types first - for i := range options.Filters { + // Group consecutive OR filters together to prevent OR logic from escaping + for i := 0; i < len(options.Filters); { filter := &options.Filters[i] // Validate and adjust filter based on column type @@ -475,8 +476,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st logicOp = "AND" } - logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp) - query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp) + // Check if this is the start of an OR group + if logicOp == "OR" { + // Collect all consecutive OR filters + orFilters := []*common.FilterOption{filter} + orCastInfo := []ColumnCastInfo{castInfo} + + j := i + 1 + for j < len(options.Filters) { + nextFilter := &options.Filters[j] + nextLogicOp := nextFilter.LogicOperator + if nextLogicOp == "" { + nextLogicOp = "AND" + } + if nextLogicOp == "OR" { + nextCastInfo := h.ValidateAndAdjustFilterForColumnType(nextFilter, model) + orFilters = append(orFilters, nextFilter) + orCastInfo = append(orCastInfo, nextCastInfo) + j++ + } else { + break + } + } + + // Apply the OR group as a single grouped condition + logger.Debug("Applying OR filter group with %d conditions", len(orFilters)) + query = h.applyOrFilterGroup(query, orFilters, orCastInfo, tableName) + i = j + } else { + // Single AND filter - apply normally + logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp) + query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp) + i++ + } } // Apply custom SQL WHERE clause (AND condition) @@ -486,6 +518,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName)) // Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere) if sanitizedWhere != "" { query = query.Where(sanitizedWhere) } @@ -497,6 +531,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName)) // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedOr = common.EnsureOuterParentheses(sanitizedOr) if sanitizedOr != "" { query = query.WhereOr(sanitizedOr) } @@ -1996,6 +2032,99 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti } } +// applyOrFilterGroup applies a group of OR filters as a single grouped condition +// This ensures OR conditions are properly grouped with parentheses to prevent OR logic from escaping +func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common.FilterOption, castInfo []ColumnCastInfo, tableName string) common.SelectQuery { + if len(filters) == 0 { + return query + } + + // Build individual filter conditions + conditions := []string{} + args := []interface{}{} + + for i, filter := range filters { + // Qualify the column name with table name if not already qualified + qualifiedColumn := h.qualifyColumnName(filter.Column, tableName) + + // Apply casting to text if needed for non-numeric columns or non-numeric values + if castInfo[i].NeedsCast { + qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn) + } + + // Build the condition based on operator + condition, filterArgs := h.buildFilterCondition(qualifiedColumn, filter, tableName) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, filterArgs...) + } + } + + if len(conditions) == 0 { + return query + } + + // Join all conditions with OR and wrap in parentheses + groupedCondition := "(" + strings.Join(conditions, " OR ") + ")" + logger.Debug("Applying grouped OR conditions: %s", groupedCondition) + + // Apply as AND condition (the OR is already inside the parentheses) + return query.Where(groupedCondition, args...) +} + +// buildFilterCondition builds a single filter condition and returns the condition string and args +func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) { + switch strings.ToLower(filter.Operator) { + case "eq", "equals": + return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value} + case "neq", "not_equals", "ne": + return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value} + case "gt", "greater_than": + return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value} + case "gte", "greater_than_equals", "ge": + return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value} + case "lt", "less_than": + return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value} + case "lte", "less_than_equals", "le": + return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value} + case "like": + return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value} + case "ilike": + return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value} + case "in": + return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value} + case "between": + // Handle between operator - exclusive (> val1 AND < val2) + if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { + return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { + return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } + logger.Warn("Invalid BETWEEN filter value format") + return "", nil + case "between_inclusive": + // Handle between inclusive operator - inclusive (>= val1 AND <= val2) + if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { + return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { + return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } + logger.Warn("Invalid BETWEEN INCLUSIVE filter value format") + return "", nil + case "is_null", "isnull": + // Check for NULL values - don't use cast for NULL checks + colName := h.qualifyColumnName(filter.Column, tableName) + return fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName), nil + case "is_not_null", "isnotnull": + // Check for NOT NULL values - don't use cast for NULL checks + colName := h.qualifyColumnName(filter.Column, tableName) + return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName), nil + default: + logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator) + return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value} + } +} + // parseTableName splits a table name that may contain schema into separate schema and table func (h *Handler) parseTableName(fullTableName string) (schema, table string) { if idx := strings.LastIndex(fullTableName, "."); idx != -1 {