diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 8d62168..5b14ce5 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -2,6 +2,7 @@ package common import ( "fmt" + "regexp" "strings" "github.com/bitechdev/ResolveSpec/pkg/logger" @@ -207,6 +208,20 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti } } } + } else if tableName != "" && !hasTablePrefix(condToCheck) { + // If tableName is provided and the condition DOESN'T have a table prefix, + // qualify unambiguous column references to prevent "ambiguous column" errors + // when there are multiple joins on the same table (e.g., recursive preloads) + columnName := extractUnqualifiedColumnName(condToCheck) + if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) { + // Qualify the column with the table name + // Be careful to only replace the column name, not other occurrences of the string + oldRef := columnName + newRef := tableName + "." + columnName + // Use word boundary matching to avoid replacing partial matches + cond = qualifyColumnInCondition(cond, oldRef, newRef) + logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName) + } } validConditions = append(validConditions, cond) @@ -483,6 +498,86 @@ func extractTableAndColumn(cond string) (table string, column string) { return "", "" } +// extractUnqualifiedColumnName extracts the column name from an unqualified condition +// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem" +// "status = 'active'" returns "status" +func extractUnqualifiedColumnName(cond string) string { + // Common SQL operators + operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "} + + // Find the column reference (left side of the operator) + minIdx := -1 + for _, op := range operators { + idx := strings.Index(cond, op) + if idx > 0 && (minIdx == -1 || idx < minIdx) { + minIdx = idx + } + } + + var columnRef string + if minIdx > 0 { + columnRef = strings.TrimSpace(cond[:minIdx]) + } else { + // No operator found, might be a single column reference + parts := strings.Fields(cond) + if len(parts) > 0 { + columnRef = parts[0] + } + } + + if columnRef == "" { + return "" + } + + // Remove any quotes + columnRef = strings.Trim(columnRef, "`\"'") + + // Return empty if it contains a dot (already qualified) or function call + if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") { + return "" + } + + return columnRef +} + +// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition +// Uses word boundaries to avoid partial matches +// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item") +// returns "table.rid_item is null" +func qualifyColumnInCondition(cond, oldRef, newRef string) string { + // Use word boundary matching with Go's supported regex syntax + // \b matches word boundaries + escapedOld := regexp.QuoteMeta(oldRef) + pattern := `\b` + escapedOld + `\b` + + re, err := regexp.Compile(pattern) + if err != nil { + // If regex fails, fall back to simple string replacement + logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err) + return strings.Replace(cond, oldRef, newRef, 1) + } + + // Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns) + result := cond + matches := re.FindAllStringIndex(cond, -1) + + // Process matches in reverse order to maintain correct indices + for i := len(matches) - 1; i >= 0; i-- { + match := matches[i] + start := match[0] + + // Check if preceded by a dot (already qualified) + if start > 0 && cond[start-1] == '.' { + continue + } + + // Replace this occurrence + result = result[:start] + newRef + result[match[1]:] + } + + return result +} + // findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses // Returns the index of the operator, or -1 if not found or only found inside parentheses func findOperatorOutsideParentheses(s string, operator string) int { diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index 2874c19..e7cefd4 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -33,16 +33,16 @@ func TestSanitizeWhereClause(t *testing.T) { expected: "", }, { - name: "valid condition with parentheses - no prefix added", + name: "valid condition with parentheses - prefix added to prevent ambiguity", where: "(status = 'active')", tableName: "users", - expected: "status = 'active'", + expected: "users.status = 'active'", }, { - name: "mixed trivial and valid conditions - no prefix added", + name: "mixed trivial and valid conditions - prefix added", where: "true AND status = 'active' AND 1=1", tableName: "users", - expected: "status = 'active'", + expected: "users.status = 'active'", }, { name: "condition with correct table prefix - unchanged", @@ -63,10 +63,10 @@ func TestSanitizeWhereClause(t *testing.T) { expected: "users.status = 'active' AND users.age > 18", }, { - name: "multiple valid conditions without prefix - no prefix added", + name: "multiple valid conditions without prefix - prefixes added", where: "status = 'active' AND age > 18", tableName: "users", - expected: "status = 'active' AND age > 18", + expected: "users.status = 'active' AND users.age > 18", }, { name: "no table name provided", @@ -90,13 +90,13 @@ func TestSanitizeWhereClause(t *testing.T) { name: "mixed case AND operators", where: "status = 'active' AND age > 18 and name = 'John'", tableName: "users", - expected: "status = 'active' AND age > 18 AND name = 'John'", + expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'", }, { name: "subquery with ORDER BY and LIMIT - allowed", where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)", tableName: "users", - expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)", + expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)", }, { name: "dangerous DELETE keyword - blocked", diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5b55730..55b3fcb 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -746,9 +746,29 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply ComputedQL fields if any if len(preload.ComputedQL) > 0 { + // Get the base table name from the related model + baseTableName := getTableNameFromModel(relatedModel) + // Convert the preload relation path to the Bun alias format + preloadAlias := relationPathToBunAlias(preload.Relation) + + logger.Debug("Applying computed columns to preload %s (alias: %s, base table: %s)", + preload.Relation, preloadAlias, baseTableName) + for colName, colExpr := range preload.ComputedQL { + // Replace table references in the expression with the preload alias + // This fixes the ambiguous column reference issue when there are multiple + // levels of recursive/nested preloads + adjustedExpr := colExpr + if baseTableName != "" && preloadAlias != "" { + adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias) + if adjustedExpr != colExpr { + logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'", + colName, colExpr, adjustedExpr) + } + } + logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName) - sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName)) + sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", adjustedExpr, colName)) // Remove the computed column from selected columns to avoid duplication for colIndex := range preload.Columns { if preload.Columns[colIndex] == colName { @@ -841,6 +861,73 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co return query } +// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def" +// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores +func relationPathToBunAlias(relationPath string) string { + if relationPath == "" { + return "" + } + // Convert to lowercase and replace dots with double underscores + alias := strings.ToLower(relationPath) + alias = strings.ReplaceAll(alias, ".", "__") + return alias +} + +// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression +// with the appropriate alias for the current preload level +// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal", +// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem" +func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string { + if sqlExpr == "" || baseTableName == "" || targetAlias == "" { + return sqlExpr + } + + // Replace both quoted and unquoted table references + // Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column" + + // Pattern 1: tablename.column (unquoted) + result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".") + + // Pattern 2: "tablename".column or "tablename"."column" (quoted table name) + result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".") + + return result +} + +// getTableNameFromModel extracts the table name from a model +// It checks the bun tag first, then falls back to converting the struct name to snake_case +func getTableNameFromModel(model interface{}) string { + if model == nil { + return "" + } + + modelType := reflect.TypeOf(model) + + // Unwrap pointers + for modelType != nil && modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return "" + } + + // Look for bun tag on embedded BaseModel + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if field.Anonymous { + bunTag := field.Tag.Get("bun") + if strings.HasPrefix(bunTag, "table:") { + return strings.TrimPrefix(bunTag, "table:") + } + } + } + + // Fallback: convert struct name to lowercase (simple heuristic) + // This handles cases like "MasterTaskItem" -> "mastertaskitem" + return strings.ToLower(modelType.Name()) +} + func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) { // Capture panics and return error response defer func() {