mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
368 lines
11 KiB
Go
368 lines
11 KiB
Go
package common
|
|
|
|
import (
|
|
"strings"
|
|
|
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
|
)
|
|
|
|
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
|
|
//
|
|
// NOTE: For preload queries, table aliases from the parent query are not valid since
|
|
// the preload executes as a separate query with its own table alias. This function
|
|
// now simply validates basic syntax without requiring or adding prefixes.
|
|
// The actual alias normalization happens in the database adapter layer.
|
|
//
|
|
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
|
|
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
|
if where == "" {
|
|
return where, nil
|
|
}
|
|
|
|
where = strings.TrimSpace(where)
|
|
|
|
// Just do basic validation - don't require or add prefixes
|
|
// The database adapter will handle alias normalization
|
|
|
|
// Check if the WHERE clause contains any qualified column references
|
|
// If it does, log a debug message but don't fail - let the adapter handle it
|
|
if strings.Contains(where, ".") {
|
|
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
|
|
"Note: In preload context, table aliases from parent query are not available. "+
|
|
"The database adapter will normalize aliases automatically.", relationName, where)
|
|
}
|
|
|
|
// Validate that it's not empty or just whitespace
|
|
if where == "" {
|
|
return where, nil
|
|
}
|
|
|
|
// Return the WHERE clause as-is
|
|
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
|
|
return where, nil
|
|
}
|
|
|
|
// IsSQLExpression reports whether cond is one of a small fixed set of SQL
// literals or constant expressions that must never receive a table prefix:
// "true", "false", "null", "1=1", "1 = 1", "0=0" and "0 = 0".
//
// The comparison is exact. It is case-sensitive and does not trim whitespace.
func IsSQLExpression(cond string) bool {
	switch cond {
	case "true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0":
		return true
	default:
		return false
	}
}
|
|
|
|
// IsTrivialCondition reports whether cond always evaluates to true and therefore
// has no filtering effect, so it can be removed from a WHERE clause.
//
// Recognized forms are matched case-insensitively, and whitespace around the
// condition is ignored:
//   - the bare literal "true"
//   - "1=1", "0=0" and "true=true", with any amount of whitespace (spaces,
//     tabs, newlines) on either side of the "=".
//
// Other comparison operators such as ">=", "<=", "!=" and "==" are never
// treated as trivial.
func IsTrivialCondition(cond string) bool {
	lowerCond := strings.ToLower(strings.TrimSpace(cond))
	if lowerCond == "true" {
		return true
	}

	// Split on the first '=' only. For ">=", "<=", "!=" or "==" the left or right
	// operand keeps an operator character, so the equality below fails.
	left, right, found := strings.Cut(lowerCond, "=")
	if !found {
		return false
	}
	left = strings.TrimSpace(left)
	right = strings.TrimSpace(right)
	if left != right {
		return false
	}

	switch left {
	case "1", "0", "true":
		return true
	}
	return false
}
|
|
|
|
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
|
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
|
//
|
|
// Parameters:
|
|
// - where: The WHERE clause string to sanitize
|
|
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
|
|
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
|
|
//
|
|
// Returns:
|
|
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
|
|
// - An empty string if all conditions were trivial or the input was empty
|
|
//
|
|
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
|
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
|
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
|
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
|
if where == "" {
|
|
return ""
|
|
}
|
|
|
|
where = strings.TrimSpace(where)
|
|
|
|
// Strip outer parentheses and re-trim
|
|
where = stripOuterParentheses(where)
|
|
|
|
// Get valid columns from the model if tableName is provided
|
|
var validColumns map[string]bool
|
|
if tableName != "" {
|
|
validColumns = getValidColumnsForTable(tableName)
|
|
}
|
|
|
|
// Build a set of allowed table prefixes (main table + preloaded relations)
|
|
allowedPrefixes := make(map[string]bool)
|
|
if tableName != "" {
|
|
allowedPrefixes[tableName] = true
|
|
}
|
|
|
|
// Add preload relation names as allowed prefixes
|
|
if len(options) > 0 && options[0] != nil {
|
|
for pi := range options[0].Preload {
|
|
if options[0].Preload[pi].Relation != "" {
|
|
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
|
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Split by AND to handle multiple conditions
|
|
conditions := splitByAND(where)
|
|
|
|
validConditions := make([]string, 0, len(conditions))
|
|
|
|
for _, cond := range conditions {
|
|
cond = strings.TrimSpace(cond)
|
|
if cond == "" {
|
|
continue
|
|
}
|
|
|
|
// Strip parentheses from the condition before checking
|
|
condToCheck := stripOuterParentheses(cond)
|
|
|
|
// Skip trivial conditions that always evaluate to true
|
|
if IsTrivialCondition(condToCheck) {
|
|
logger.Debug("Removing trivial condition: '%s'", cond)
|
|
continue
|
|
}
|
|
|
|
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
|
if tableName != "" && hasTablePrefix(condToCheck) {
|
|
// Extract the current prefix and column name
|
|
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
|
|
|
if currentPrefix != "" && columnName != "" {
|
|
// Check if the prefix is allowed (main table or preload relation)
|
|
if !allowedPrefixes[currentPrefix] {
|
|
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
|
// Replace the incorrect prefix with the correct main table name
|
|
oldRef := currentPrefix + "." + columnName
|
|
newRef := tableName + "." + columnName
|
|
cond = strings.Replace(cond, oldRef, newRef, 1)
|
|
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
|
|
} else {
|
|
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
validConditions = append(validConditions, cond)
|
|
}
|
|
|
|
if len(validConditions) == 0 {
|
|
return ""
|
|
}
|
|
|
|
result := strings.Join(validConditions, " AND ")
|
|
|
|
if result != where {
|
|
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// stripOuterParentheses repeatedly removes parentheses that enclose the whole
// string, trimming whitespace after each step. For example, "((a = 1))"
// becomes "a = 1".
//
// Parentheses are only removed when the '(' at the start is matched by the ')'
// at the end, so "(a) AND (b)" is returned unchanged. Parentheses inside quoted
// text ('...', "...", `...`) are ignored while matching. This means a clause
// such as "(a = '(') AND (b = ')')" is not mistaken for one enclosed group.
func stripOuterParentheses(s string) string {
	// enclosesAll reports whether the '(' at expr[0] is closed by expr's final byte.
	enclosesAll := func(expr string) bool {
		depth := 0
		var quote byte // active quote character, or 0 outside quotes
		for i := 0; i < len(expr); i++ {
			c := expr[i]
			if quote != 0 {
				// A doubled quote ('it''s') closes and immediately reopens the
				// literal, so toggling on every quote character stays correct.
				if c == quote {
					quote = 0
				}
				continue
			}
			switch c {
			case '\'', '"', '`':
				quote = c
			case '(':
				depth++
			case ')':
				depth--
				if depth == 0 {
					// The first group closed; it encloses everything only if
					// this is the last byte.
					return i == len(expr)-1
				}
			}
		}
		return false
	}

	s = strings.TrimSpace(s)
	for len(s) >= 2 && s[0] == '(' && s[len(s)-1] == ')' && enclosesAll(s) {
		s = strings.TrimSpace(s[1 : len(s)-1])
	}
	return s
}
|
|
|
|
// splitByAND splits a WHERE clause into its top-level AND-joined conditions.
//
// The AND keyword is matched case-insensitively ("AND", "and", "And", ...) and
// must be surrounded by whitespace (space, tab, CR or LF). An AND that appears
// inside parentheses or inside quoted text ('...', "...", `...`) is not a
// separator. As a result, string literals and grouped expressions are never
// cut apart.
//
// The returned pieces are not trimmed; callers should TrimSpace them. A clause
// with no top-level AND is returned as a single-element slice, and "" yields
// [""].
//
// Limitation: the AND of "x BETWEEN a AND b" is still treated as a separator.
func splitByAND(where string) []string {
	isSpace := func(c byte) bool {
		return c == ' ' || c == '\t' || c == '\n' || c == '\r'
	}

	var parts []string
	var quote byte // active quote character, or 0 outside quotes
	depth := 0     // current parenthesis nesting level
	start := 0     // start index of the condition currently being scanned

	for i := 0; i < len(where); i++ {
		c := where[i]
		if quote != 0 {
			// A doubled quote ('it''s') closes and immediately reopens the
			// literal, so toggling on every quote character handles escapes.
			if c == quote {
				quote = 0
			}
			continue
		}
		switch {
		case c == '\'' || c == '"' || c == '`':
			quote = c
		case c == '(':
			depth++
		case c == ')':
			if depth > 0 {
				depth--
			}
		case depth == 0 && isSpace(c) && i+4 < len(where) &&
			strings.EqualFold(where[i+1:i+4], "and") && isSpace(where[i+4]):
			parts = append(parts, where[start:i])
			start = i + 5
			i += 4 // skip "AND" and its trailing whitespace byte
		}
	}
	return append(parts, where[start:])
}
|
|
|
|
// hasTablePrefix reports whether cond might contain a qualified reference such
// as table.column, `table`.`column` or "table"."column".
//
// This is a cheap pre-check: it returns true whenever cond contains any '.' at
// all, including dots inside literals or decimal numbers.
func hasTablePrefix(cond string) bool {
	return strings.IndexByte(cond, '.') >= 0
}
|
|
|
|
// ExtractColumnName extracts the column name from a WHERE condition
|
|
// For example: "status = 'active'" returns "status"
|
|
func ExtractColumnName(cond string) string {
|
|
// Common SQL operators
|
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
|
|
|
for _, op := range operators {
|
|
if idx := strings.Index(cond, op); idx > 0 {
|
|
columnName := strings.TrimSpace(cond[:idx])
|
|
// Remove quotes if present
|
|
columnName = strings.Trim(columnName, "`\"'")
|
|
return columnName
|
|
}
|
|
}
|
|
|
|
// If no operator found, check if it's a simple identifier (for boolean columns)
|
|
parts := strings.Fields(cond)
|
|
if len(parts) > 0 {
|
|
columnName := strings.Trim(parts[0], "`\"'")
|
|
// Check if it's a valid identifier (not a SQL keyword)
|
|
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
|
return columnName
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// IsSQLKeyword reports whether word is one of the SQL keywords that must not be
// mistaken for a column name. The comparison is exact and case-sensitive, so
// callers are expected to pass a lower-cased word.
func IsSQLKeyword(word string) bool {
	switch word {
	case "select", "from", "where", "and", "or", "not", "in", "is",
		"null", "true", "false", "like", "between", "exists":
		return true
	}
	return false
}
|
|
|
|
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
|
|
// Returns a map of column names for fast lookup, or nil if the model is not found
|
|
func getValidColumnsForTable(tableName string) map[string]bool {
|
|
// Try to get the model from the registry
|
|
model, err := modelregistry.GetModelByName(tableName)
|
|
if err != nil {
|
|
// Model not found, return nil to indicate we should use fallback behavior
|
|
return nil
|
|
}
|
|
|
|
// Get SQL columns from the model
|
|
columns := reflection.GetSQLModelColumns(model)
|
|
if len(columns) == 0 {
|
|
// No columns found, return nil
|
|
return nil
|
|
}
|
|
|
|
// Build a map for fast lookup
|
|
columnMap := make(map[string]bool, len(columns))
|
|
for _, col := range columns {
|
|
columnMap[strings.ToLower(col)] = true
|
|
}
|
|
|
|
return columnMap
|
|
}
|
|
|
|
// extractTableAndColumn extracts the table prefix and column name from the
// leading column reference of a condition.
//
// For example, "users.status = 'active'" returns ("users", "status") and
// `"u"."id" = 1` returns ("u", "id"). The column reference is the text to the
// left of the earliest recognized operator, or the first word when there is no
// operator.
//
// After surrounding quotes are removed, both the prefix and the column must be
// plain identifiers. The prefix may be schema-qualified, as in "public.users".
// Otherwise nothing is extracted: "NOT t.id = 1" and "LOWER(t.name) = 'x'"
// both yield ("", ""). This prevents callers from rewriting part of a larger
// expression by mistake.
//
// Returns empty strings if no valid table prefix is found.
func extractTableAndColumn(cond string) (table string, column string) {
	const quoteChars = "`\"'"

	// Operators are matched with surrounding spaces. " NOT " lets
	// "t.col NOT IN (...)" and "t.col NOT LIKE ..." resolve to "t.col".
	operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ",
		" LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}

	// isIdentifier reports whether s is a non-empty run of identifier bytes:
	// ASCII letters, digits, '_' and '$', plus any non-ASCII byte so that UTF-8
	// identifiers are accepted.
	isIdentifier := func(s string) bool {
		if s == "" {
			return false
		}
		for i := 0; i < len(s); i++ {
			c := s[i]
			ok := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
				(c >= '0' && c <= '9') || c == '_' || c == '$' || c >= 0x80
			if !ok {
				return false
			}
		}
		return true
	}

	// Use the operator that occurs first in cond, so an operator inside a
	// right-hand literal (e.g. "t.a LIKE 'x = y'") is never chosen.
	opIdx := -1
	for _, op := range operators {
		if idx := strings.Index(cond, op); idx > 0 && (opIdx < 0 || idx < opIdx) {
			opIdx = idx
		}
	}

	var columnRef string
	if opIdx > 0 {
		columnRef = strings.TrimSpace(cond[:opIdx])
	} else if fields := strings.Fields(cond); len(fields) > 0 {
		// No operator: the condition may be a bare boolean column reference.
		columnRef = fields[0]
	}

	columnRef = strings.Trim(columnRef, quoteChars)
	dotIdx := strings.LastIndex(columnRef, ".")
	if dotIdx <= 0 {
		return "", ""
	}

	column = strings.Trim(columnRef[dotIdx+1:], quoteChars)
	if !isIdentifier(column) {
		return "", ""
	}

	// The prefix may itself be qualified (schema.table); validate each segment.
	segments := strings.Split(columnRef[:dotIdx], ".")
	for i, seg := range segments {
		seg = strings.Trim(seg, quoteChars)
		if !isIdentifier(seg) {
			return "", ""
		}
		segments[i] = seg
	}
	return strings.Join(segments, "."), column
}
|
|
|
|
// isValidColumn reports whether columnName is present in validColumns.
//
// The name is lower-cased before the lookup, so validColumns is expected to
// hold lower-case keys. A nil map means no model metadata is available, and in
// that case every column is accepted.
func isValidColumn(columnName string, validColumns map[string]bool) bool {
	if validColumns != nil {
		return validColumns[strings.ToLower(columnName)]
	}
	return true
}
|