mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-31 17:28:58 +00:00
Enhanced the SanitizeWhereClause function
This commit is contained in:
@@ -78,17 +78,22 @@ func IsTrivialCondition(cond string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause string to sanitize
|
||||
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
|
||||
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
|
||||
//
|
||||
// Returns:
|
||||
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
|
||||
// - An empty string if all conditions were trivial or the input was empty
|
||||
func SanitizeWhereClause(where string, tableName string) string {
|
||||
//
|
||||
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
}
|
||||
@@ -104,6 +109,22 @@ func SanitizeWhereClause(where string, tableName string) string {
|
||||
validColumns = getValidColumnsForTable(tableName)
|
||||
}
|
||||
|
||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||
allowedPrefixes := make(map[string]bool)
|
||||
if tableName != "" {
|
||||
allowedPrefixes[tableName] = true
|
||||
}
|
||||
|
||||
// Add preload relation names as allowed prefixes
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
for pi := range options[0].Preload {
|
||||
if options[0].Preload[pi].Relation != "" {
|
||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
conditions := splitByAND(where)
|
||||
|
||||
@@ -124,22 +145,23 @@ func SanitizeWhereClause(where string, tableName string) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||
// attempt to add it
|
||||
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||
// Extract the column name and prefix it
|
||||
columnName := ExtractColumnName(condToCheck)
|
||||
if columnName != "" {
|
||||
// Only prefix if this is a valid column in the model
|
||||
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
||||
if tableName != "" && hasTablePrefix(condToCheck) {
|
||||
// Extract the current prefix and column name
|
||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||
|
||||
if currentPrefix != "" && columnName != "" {
|
||||
// Check if the prefix is allowed (main table or preload relation)
|
||||
if !allowedPrefixes[currentPrefix] {
|
||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace in the original condition (without stripped parens)
|
||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||
// Replace the incorrect prefix with the correct main table name
|
||||
oldRef := currentPrefix + "." + columnName
|
||||
newRef := tableName + "." + columnName
|
||||
cond = strings.Replace(cond, oldRef, newRef, 1)
|
||||
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
|
||||
} else {
|
||||
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -288,6 +310,53 @@ func getValidColumnsForTable(tableName string) map[string]bool {
|
||||
return columnMap
|
||||
}
|
||||
|
||||
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
|
||||
// For example: "users.status = 'active'" returns ("users", "status")
|
||||
// Returns empty strings if no table prefix is found
|
||||
func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Common SQL operators to find the column reference
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
|
||||
var columnRef string
|
||||
|
||||
// Find the column reference (left side of the operator)
|
||||
for _, op := range operators {
|
||||
if idx := strings.Index(cond, op); idx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:idx])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no operator found, the whole condition might be the column reference
|
||||
if columnRef == "" {
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef = parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
if columnRef == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// Check if it contains a dot (qualified reference)
|
||||
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||
table = columnRef[:dotIdx]
|
||||
column = columnRef[dotIdx+1:]
|
||||
|
||||
// Remove quotes from table and column if present
|
||||
table = strings.Trim(table, "`\"'")
|
||||
column = strings.Trim(column, "`\"'")
|
||||
|
||||
return table, column
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// isValidColumn checks if a column name exists in the valid columns map
|
||||
// Handles case-insensitive comparison
|
||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
|
||||
Reference in New Issue
Block a user