diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 51e6a5c..01d7228 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -5,6 +5,8 @@ import ( "strings" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/modelregistry" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains @@ -135,6 +137,15 @@ func SanitizeWhereClause(where string, tableName string) string { where = strings.TrimSpace(where) + // Strip outer parentheses and re-trim + where = stripOuterParentheses(where) + + // Get valid columns from the model if tableName is provided + var validColumns map[string]bool + if tableName != "" { + validColumns = getValidColumnsForTable(tableName) + } + // Split by AND to handle multiple conditions conditions := splitByAND(where) @@ -146,22 +157,32 @@ func SanitizeWhereClause(where string, tableName string) string { continue } + // Strip parentheses from the condition before checking + condToCheck := stripOuterParentheses(cond) + // Skip trivial conditions that always evaluate to true - if IsTrivialCondition(cond) { + if IsTrivialCondition(condToCheck) { logger.Debug("Removing trivial condition: '%s'", cond) continue } // If tableName is provided and the condition doesn't already have a table prefix, // attempt to add it - if tableName != "" && !hasTablePrefix(cond) { + if tableName != "" && !hasTablePrefix(condToCheck) { // Check if this is a SQL expression/literal that shouldn't be prefixed - if !IsSQLExpression(strings.ToLower(cond)) { + if !IsSQLExpression(strings.ToLower(condToCheck)) { // Extract the column name and prefix it - columnName := ExtractColumnName(cond) + columnName := ExtractColumnName(condToCheck) if columnName != "" { - cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1) - logger.Debug("Prefixed column in condition: '%s'", cond) + // Only prefix if this is a valid column in the model + // If we don't have model info (validColumns is nil), prefix anyway for backward compatibility + if validColumns == nil || isValidColumn(columnName, validColumns) { + // Replace in the original condition (without stripped parens) + cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1) + logger.Debug("Prefixed column in condition: '%s'", cond) + } else { + logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName) + } } } } @@ -182,6 +203,43 @@ func SanitizeWhereClause(where string, tableName string) string { return result } +// stripOuterParentheses removes matching outer parentheses from a string +// It handles nested parentheses correctly +func stripOuterParentheses(s string) string { + s = strings.TrimSpace(s) + + for { + if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' { + return s + } + + // Check if these parentheses match (i.e., they're the outermost pair) + depth := 0 + matched := false + for i := 0; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 && i == len(s)-1 { + matched = true + } else if depth == 0 { + // Found a closing paren before the end, so outer parens don't match + return s + } + } + } + + if !matched { + return s + } + + // Strip the outer parentheses and continue + s = strings.TrimSpace(s[1 : len(s)-1]) + } +} + // splitByAND splits a WHERE clause by AND operators (case-insensitive) // This is a simple split that doesn't handle nested parentheses or complex expressions func splitByAND(where string) []string { @@ -245,3 +303,38 @@ func IsSQLKeyword(word string) bool { } return false } + +// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry +// Returns a map of column names for fast lookup, or nil if the model is not found +func getValidColumnsForTable(tableName string) map[string]bool { + // Try to get the model from the registry + model, err := modelregistry.GetModelByName(tableName) + if err != nil { + // Model not found, return nil to indicate we should use fallback behavior + return nil + } + + // Get SQL columns from the model + columns := reflection.GetSQLModelColumns(model) + if len(columns) == 0 { + // No columns found, return nil + return nil + } + + // Build a map for fast lookup + columnMap := make(map[string]bool, len(columns)) + for _, col := range columns { + columnMap[strings.ToLower(col)] = true + } + + return columnMap +} + +// isValidColumn checks if a column name exists in the valid columns map +// Handles case-insensitive comparison +func isValidColumn(columnName string, validColumns map[string]bool) bool { + if validColumns == nil { + return true // No model info, assume valid + } + return validColumns[strings.ToLower(columnName)] +} diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go new file mode 100644 index 0000000..2b0d5a0 --- /dev/null +++ b/pkg/common/sql_helpers_test.go @@ -0,0 +1,224 @@ +package common + +import ( + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/modelregistry" +) + +func TestSanitizeWhereClause(t *testing.T) { + tests := []struct { + name string + where string + tableName string + expected string + }{ + { + name: "trivial conditions in parentheses", + where: "(true AND true AND true)", + tableName: "mastertask", + expected: "", + }, + { + name: "trivial conditions without parentheses", + where: "true AND true AND true", + tableName: "mastertask", + expected: "", + }, + { + name: "single trivial condition", + where: "true", + tableName: "mastertask", + expected: "", + }, + { + name: "valid condition with parentheses", + where: "(status = 'active')", + tableName: "users", + expected: "users.status = 'active'", + }, + { + name: "mixed trivial and valid conditions", + where: "true AND status = 'active' AND 1=1", + tableName: "users", + expected: "users.status = 'active'", + }, + { + name: "condition already with table prefix", + where: "users.status = 'active'", + tableName: "users", + expected: "users.status = 'active'", + }, + { + name: "multiple valid conditions", + where: "status = 'active' AND age > 18", + tableName: "users", + expected: "users.status = 'active' AND users.age > 18", + }, + { + name: "no table name provided", + where: "status = 'active'", + tableName: "", + expected: "status = 'active'", + }, + { + name: "empty where clause", + where: "", + tableName: "users", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeWhereClause(tt.where, tt.tableName) + if result != tt.expected { + t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected) + } + }) + } +} + +func TestStripOuterParentheses(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "single level parentheses", + input: "(true)", + expected: "true", + }, + { + name: "multiple levels", + input: "((true))", + expected: "true", + }, + { + name: "no parentheses", + input: "true", + expected: "true", + }, + { + name: "mismatched parentheses", + input: "(true", + expected: "(true", + }, + { + name: "complex expression", + input: "(a AND b)", + expected: "a AND b", + }, + { + name: "nested but not outer", + input: "(a AND (b OR c)) AND d", + expected: "(a AND (b OR c)) AND d", + }, + { + name: "with spaces", + input: " ( true ) ", + expected: "true", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := stripOuterParentheses(tt.input) + if result != tt.expected { + t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestIsTrivialCondition(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"true", "true", true}, + {"true with spaces", " true ", true}, + {"TRUE uppercase", "TRUE", true}, + {"1=1", "1=1", true}, + {"1 = 1", "1 = 1", true}, + {"true = true", "true = true", true}, + {"valid condition", "status = 'active'", false}, + {"false", "false", false}, + {"column name", "is_active", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsTrivialCondition(tt.input) + if result != tt.expected { + t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected) + } + }) + } +} + +// Test model for model-aware sanitization tests +type MasterTask struct { + ID int `bun:"id,pk"` + Name string `bun:"name"` + Status string `bun:"status"` + UserID int `bun:"user_id"` +} + +func TestSanitizeWhereClauseWithModel(t *testing.T) { + // Register the test model + err := modelregistry.RegisterModel(MasterTask{}, "mastertask") + if err != nil { + // Model might already be registered, ignore error + t.Logf("Model registration returned: %v", err) + } + + tests := []struct { + name string + where string + tableName string + expected string + }{ + { + name: "valid column gets prefixed", + where: "status = 'active'", + tableName: "mastertask", + expected: "mastertask.status = 'active'", + }, + { + name: "multiple valid columns get prefixed", + where: "status = 'active' AND user_id = 123", + tableName: "mastertask", + expected: "mastertask.status = 'active' AND mastertask.user_id = 123", + }, + { + name: "invalid column does not get prefixed", + where: "invalid_column = 'value'", + tableName: "mastertask", + expected: "invalid_column = 'value'", + }, + { + name: "mix of valid and trivial conditions", + where: "true AND status = 'active' AND 1=1", + tableName: "mastertask", + expected: "mastertask.status = 'active'", + }, + { + name: "parentheses with valid column", + where: "(status = 'active')", + tableName: "mastertask", + expected: "mastertask.status = 'active'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeWhereClause(tt.where, tt.tableName) + if result != tt.expected { + t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected) + } + }) + } +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 40a395b..a7bbc89 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -652,6 +652,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) } } + // Apply WHERE clause if len(preload.Where) > 0 { diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index a7541cf..7f1debb 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -267,6 +267,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E h.resolveRelationNamesInOptions(&options, model) } + //Always sort according to the primary key if no sorting is specified + if len(options.Sort) == 0 { + pkName := reflection.GetPrimaryKeyName(model) + options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}} + } + return options }