From 1cd04b70833d232dfe6a54ede1967439a4754be5 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 20 Nov 2025 17:02:27 +0200 Subject: [PATCH] Better where clause handling for preloads --- pkg/common/sql_helpers.go | 136 ++++++++++++++++++++++++++++++++++++ pkg/resolvespec/handler.go | 65 +---------------- pkg/restheadspec/handler.go | 65 +---------------- 3 files changed, 138 insertions(+), 128 deletions(-) create mode 100644 pkg/common/sql_helpers.go diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go new file mode 100644 index 0000000..04aa3f2 --- /dev/null +++ b/pkg/common/sql_helpers.go @@ -0,0 +1,136 @@ +package common + +import ( + "fmt" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains +// the relation prefix (alias). If not present, it attempts to add it to column references. +// Returns the fixed WHERE clause and an error if it cannot be safely fixed. +func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) { + if where == "" { + return where, nil + } + + // Check if the relation name is already present in the WHERE clause + lowerWhere := strings.ToLower(where) + lowerRelation := strings.ToLower(relationName) + + // Check for patterns like "relation.", "relation ", or just "relation" followed by a dot + if strings.Contains(lowerWhere, lowerRelation+".") || + strings.Contains(lowerWhere, "`"+lowerRelation+"`.") || + strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { + // Relation prefix is already present + return where, nil + } + + // If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), + // we can't safely auto-fix it - require explicit prefix + if strings.Contains(lowerWhere, " or ") || + strings.Contains(where, "(") || + strings.Contains(where, ")") { + return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName) + } + + // Try to add the relation prefix to simple column references + // This handles basic cases like "column = value" or "column = value AND other_column = value" + // Split by AND to handle multiple conditions (case-insensitive) + originalConditions := strings.Split(where, " AND ") + + // If uppercase split didn't work, try lowercase + if len(originalConditions) == 1 { + originalConditions = strings.Split(where, " and ") + } + + fixedConditions := make([]string, 0, len(originalConditions)) + + for _, cond := range originalConditions { + cond = strings.TrimSpace(cond) + if cond == "" { + continue + } + + // Check if this condition already has a table prefix (contains a dot) + if strings.Contains(cond, ".") { + fixedConditions = append(fixedConditions, cond) + continue + } + + // Check if this is a SQL expression/literal that shouldn't be prefixed + lowerCond := strings.ToLower(strings.TrimSpace(cond)) + if IsSQLExpression(lowerCond) { + // Don't prefix SQL expressions like "true", "false", "1=1", etc. + fixedConditions = append(fixedConditions, cond) + continue + } + + // Extract the column name (first identifier before operator) + columnName := ExtractColumnName(cond) + if columnName == "" { + // Can't identify column name, require explicit prefix + return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond) + } + + // Add relation prefix to the column name only + fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1) + fixedConditions = append(fixedConditions, fixedCond) + } + + fixedWhere := strings.Join(fixedConditions, " AND ") + logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere) + return fixedWhere, nil +} + +// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed +func IsSQLExpression(cond string) bool { + // Common SQL literals and expressions + sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"} + for _, literal := range sqlLiterals { + if cond == literal { + return true + } + } + return false +} + +// ExtractColumnName extracts the column name from a WHERE condition +// For example: "status = 'active'" returns "status" +func ExtractColumnName(cond string) string { + // Common SQL operators + operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "} + + for _, op := range operators { + if idx := strings.Index(cond, op); idx > 0 { + columnName := strings.TrimSpace(cond[:idx]) + // Remove quotes if present + columnName = strings.Trim(columnName, "`\"'") + return columnName + } + } + + // If no operator found, check if it's a simple identifier (for boolean columns) + parts := strings.Fields(cond) + if len(parts) > 0 { + columnName := strings.Trim(parts[0], "`\"'") + // Check if it's a valid identifier (not a SQL keyword) + if !IsSQLKeyword(strings.ToLower(columnName)) { + return columnName + } + } + + return "" +} + +// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name +func IsSQLKeyword(word string) bool { + keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"} + for _, kw := range keywords { + if word == kw { + return true + } + } + return false +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 0e532cb..2372733 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1105,69 +1105,6 @@ type relationshipInfo struct { relatedModel interface{} } -// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains -// the relation prefix (alias). If not present, it attempts to add it to column references. -// Returns the fixed WHERE clause and an error if it cannot be safely fixed. -func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) { - if where == "" { - return where, nil - } - - // Check if the relation name is already present in the WHERE clause - lowerWhere := strings.ToLower(where) - lowerRelation := strings.ToLower(relationName) - - // Check for patterns like "relation.", "relation ", or just "relation" followed by a dot - if strings.Contains(lowerWhere, lowerRelation+".") || - strings.Contains(lowerWhere, "`"+lowerRelation+"`.") || - strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { - // Relation prefix is already present - return where, nil - } - - // If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), - // we can't safely auto-fix it - require explicit prefix - if strings.Contains(lowerWhere, " or ") || - strings.Contains(where, "(") || - strings.Contains(where, ")") { - return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName) - } - - // Try to add the relation prefix to simple column references - // This handles basic cases like "column = value" or "column = value AND other_column = value" - // Split by AND to handle multiple conditions (case-insensitive) - originalConditions := strings.Split(where, " AND ") - - // If uppercase split didn't work, try lowercase - if len(originalConditions) == 1 { - originalConditions = strings.Split(where, " and ") - } - - fixedConditions := make([]string, 0, len(originalConditions)) - - for _, cond := range originalConditions { - cond = strings.TrimSpace(cond) - if cond == "" { - continue - } - - // Check if this condition already has a table prefix (contains a dot) - if strings.Contains(cond, ".") { - fixedConditions = append(fixedConditions, cond) - continue - } - - // Add relation prefix to the column name - // This prefixes the entire condition with "relationName." - fixedCond := fmt.Sprintf("%s.%s", relationName, cond) - fixedConditions = append(fixedConditions, fixedCond) - } - - fixedWhere := strings.Join(fixedConditions, " AND ") - logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere) - return fixedWhere, nil -} - func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { modelType := reflect.TypeOf(model) @@ -1197,7 +1134,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // Validate and fix WHERE clause to ensure it contains the relation prefix if len(preload.Where) > 0 { - fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, relationFieldName) + fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName) if err != nil { logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err) panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err)) diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 29b201c..a1777ca 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -200,69 +200,6 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma // parseOptionsFromHeaders is now implemented in headers.go -// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains -// the relation prefix (alias). If not present, it attempts to add it to column references. -// Returns the fixed WHERE clause and an error if it cannot be safely fixed. -func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) { - if where == "" { - return where, nil - } - - // Check if the relation name is already present in the WHERE clause - lowerWhere := strings.ToLower(where) - lowerRelation := strings.ToLower(relationName) - - // Check for patterns like "relation.", "relation ", or just "relation" followed by a dot - if strings.Contains(lowerWhere, lowerRelation+".") || - strings.Contains(lowerWhere, "`"+lowerRelation+"`.") || - strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { - // Relation prefix is already present - return where, nil - } - - // If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), - // we can't safely auto-fix it - require explicit prefix - if strings.Contains(lowerWhere, " or ") || - strings.Contains(where, "(") || - strings.Contains(where, ")") { - return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName) - } - - // Try to add the relation prefix to simple column references - // This handles basic cases like "column = value" or "column = value AND other_column = value" - // Split by AND to handle multiple conditions (case-insensitive) - originalConditions := strings.Split(where, " AND ") - - // If uppercase split didn't work, try lowercase - if len(originalConditions) == 1 { - originalConditions = strings.Split(where, " and ") - } - - fixedConditions := make([]string, 0, len(originalConditions)) - - for _, cond := range originalConditions { - cond = strings.TrimSpace(cond) - if cond == "" { - continue - } - - // Check if this condition already has a table prefix (contains a dot) - if strings.Contains(cond, ".") { - fixedConditions = append(fixedConditions, cond) - continue - } - - // Add relation prefix to the column name - // This prefixes the entire condition with "relationName." - fixedCond := fmt.Sprintf("%s.%s", relationName, cond) - fixedConditions = append(fixedConditions, fixedCond) - } - - fixedWhere := strings.Join(fixedConditions, " AND ") - logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere) - return fixedWhere, nil -} - func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) { // Capture panics and return error response defer func() { @@ -410,7 +347,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Validate and fix WHERE clause to ensure it contains the relation prefix if len(preload.Where) > 0 { - fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, preload.Relation) + fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation) if err != nil { logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err) h.sendError(w, http.StatusBadRequest, "invalid_preload_where",