mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
Fixed for relation preloading
This commit is contained in:
parent
5862016031
commit
e8111c01aa
@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@ -207,6 +208,20 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||||
|
// If tableName is provided and the condition DOESN'T have a table prefix,
|
||||||
|
// qualify unambiguous column references to prevent "ambiguous column" errors
|
||||||
|
// when there are multiple joins on the same table (e.g., recursive preloads)
|
||||||
|
columnName := extractUnqualifiedColumnName(condToCheck)
|
||||||
|
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
|
||||||
|
// Qualify the column with the table name
|
||||||
|
// Be careful to only replace the column name, not other occurrences of the string
|
||||||
|
oldRef := columnName
|
||||||
|
newRef := tableName + "." + columnName
|
||||||
|
// Use word boundary matching to avoid replacing partial matches
|
||||||
|
cond = qualifyColumnInCondition(cond, oldRef, newRef)
|
||||||
|
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
validConditions = append(validConditions, cond)
|
validConditions = append(validConditions, cond)
|
||||||
@ -483,6 +498,86 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
|||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUnqualifiedColumnName extracts the column name from an unqualified condition
|
||||||
|
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
|
||||||
|
// "status = 'active'" returns "status"
|
||||||
|
func extractUnqualifiedColumnName(cond string) string {
|
||||||
|
// Common SQL operators
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||||
|
|
||||||
|
// Find the column reference (left side of the operator)
|
||||||
|
minIdx := -1
|
||||||
|
for _, op := range operators {
|
||||||
|
idx := strings.Index(cond, op)
|
||||||
|
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||||
|
minIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var columnRef string
|
||||||
|
if minIdx > 0 {
|
||||||
|
columnRef = strings.TrimSpace(cond[:minIdx])
|
||||||
|
} else {
|
||||||
|
// No operator found, might be a single column reference
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnRef = parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if columnRef == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any quotes
|
||||||
|
columnRef = strings.Trim(columnRef, "`\"'")
|
||||||
|
|
||||||
|
// Return empty if it contains a dot (already qualified) or function call
|
||||||
|
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnRef
|
||||||
|
}
|
||||||
|
|
||||||
|
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
|
||||||
|
// Uses word boundaries to avoid partial matches
|
||||||
|
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
|
||||||
|
// returns "table.rid_item is null"
|
||||||
|
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
|
||||||
|
// Use word boundary matching with Go's supported regex syntax
|
||||||
|
// \b matches word boundaries
|
||||||
|
escapedOld := regexp.QuoteMeta(oldRef)
|
||||||
|
pattern := `\b` + escapedOld + `\b`
|
||||||
|
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
// If regex fails, fall back to simple string replacement
|
||||||
|
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
|
||||||
|
return strings.Replace(cond, oldRef, newRef, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
|
||||||
|
result := cond
|
||||||
|
matches := re.FindAllStringIndex(cond, -1)
|
||||||
|
|
||||||
|
// Process matches in reverse order to maintain correct indices
|
||||||
|
for i := len(matches) - 1; i >= 0; i-- {
|
||||||
|
match := matches[i]
|
||||||
|
start := match[0]
|
||||||
|
|
||||||
|
// Check if preceded by a dot (already qualified)
|
||||||
|
if start > 0 && cond[start-1] == '.' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace this occurrence
|
||||||
|
result = result[:start] + newRef + result[match[1]:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
||||||
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
||||||
func findOperatorOutsideParentheses(s string, operator string) int {
|
func findOperatorOutsideParentheses(s string, operator string) int {
|
||||||
|
|||||||
@ -33,16 +33,16 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid condition with parentheses - no prefix added",
|
name: "valid condition with parentheses - prefix added to prevent ambiguity",
|
||||||
where: "(status = 'active')",
|
where: "(status = 'active')",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed trivial and valid conditions - no prefix added",
|
name: "mixed trivial and valid conditions - prefix added",
|
||||||
where: "true AND status = 'active' AND 1=1",
|
where: "true AND status = 'active' AND 1=1",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "condition with correct table prefix - unchanged",
|
name: "condition with correct table prefix - unchanged",
|
||||||
@ -63,10 +63,10 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "users.status = 'active' AND users.age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple valid conditions without prefix - no prefix added",
|
name: "multiple valid conditions without prefix - prefixes added",
|
||||||
where: "status = 'active' AND age > 18",
|
where: "status = 'active' AND age > 18",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active' AND age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no table name provided",
|
name: "no table name provided",
|
||||||
@ -90,13 +90,13 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
name: "mixed case AND operators",
|
name: "mixed case AND operators",
|
||||||
where: "status = 'active' AND age > 18 and name = 'John'",
|
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active' AND age > 18 AND name = 'John'",
|
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "subquery with ORDER BY and LIMIT - allowed",
|
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||||
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "dangerous DELETE keyword - blocked",
|
name: "dangerous DELETE keyword - blocked",
|
||||||
|
|||||||
@ -746,9 +746,29 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
|
|
||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
if len(preload.ComputedQL) > 0 {
|
if len(preload.ComputedQL) > 0 {
|
||||||
|
// Get the base table name from the related model
|
||||||
|
baseTableName := getTableNameFromModel(relatedModel)
|
||||||
|
// Convert the preload relation path to the Bun alias format
|
||||||
|
preloadAlias := relationPathToBunAlias(preload.Relation)
|
||||||
|
|
||||||
|
logger.Debug("Applying computed columns to preload %s (alias: %s, base table: %s)",
|
||||||
|
preload.Relation, preloadAlias, baseTableName)
|
||||||
|
|
||||||
for colName, colExpr := range preload.ComputedQL {
|
for colName, colExpr := range preload.ComputedQL {
|
||||||
|
// Replace table references in the expression with the preload alias
|
||||||
|
// This fixes the ambiguous column reference issue when there are multiple
|
||||||
|
// levels of recursive/nested preloads
|
||||||
|
adjustedExpr := colExpr
|
||||||
|
if baseTableName != "" && preloadAlias != "" {
|
||||||
|
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||||
|
if adjustedExpr != colExpr {
|
||||||
|
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
||||||
|
colName, colExpr, adjustedExpr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
||||||
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", adjustedExpr, colName))
|
||||||
// Remove the computed column from selected columns to avoid duplication
|
// Remove the computed column from selected columns to avoid duplication
|
||||||
for colIndex := range preload.Columns {
|
for colIndex := range preload.Columns {
|
||||||
if preload.Columns[colIndex] == colName {
|
if preload.Columns[colIndex] == colName {
|
||||||
@ -841,6 +861,73 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def"
|
||||||
|
// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores
|
||||||
|
func relationPathToBunAlias(relationPath string) string {
|
||||||
|
if relationPath == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Convert to lowercase and replace dots with double underscores
|
||||||
|
alias := strings.ToLower(relationPath)
|
||||||
|
alias = strings.ReplaceAll(alias, ".", "__")
|
||||||
|
return alias
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
||||||
|
// with the appropriate alias for the current preload level
|
||||||
|
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
||||||
|
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
||||||
|
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
||||||
|
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
||||||
|
return sqlExpr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace both quoted and unquoted table references
|
||||||
|
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
||||||
|
|
||||||
|
// Pattern 1: tablename.column (unquoted)
|
||||||
|
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
||||||
|
|
||||||
|
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
||||||
|
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableNameFromModel extracts the table name from a model
|
||||||
|
// It checks the bun tag first, then falls back to converting the struct name to snake_case
|
||||||
|
func getTableNameFromModel(model interface{}) string {
|
||||||
|
if model == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers
|
||||||
|
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for bun tag on embedded BaseModel
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if field.Anonymous {
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.HasPrefix(bunTag, "table:") {
|
||||||
|
return strings.TrimPrefix(bunTag, "table:")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||||
|
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||||
|
return strings.ToLower(modelType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user