mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
Better SanitizeWhereClause
This commit is contained in:
parent
9a3564f05f
commit
02c9b96b0c
@ -5,6 +5,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
@ -135,6 +137,15 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
|
|
||||||
where = strings.TrimSpace(where)
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Strip outer parentheses and re-trim
|
||||||
|
where = stripOuterParentheses(where)
|
||||||
|
|
||||||
|
// Get valid columns from the model if tableName is provided
|
||||||
|
var validColumns map[string]bool
|
||||||
|
if tableName != "" {
|
||||||
|
validColumns = getValidColumnsForTable(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
// Split by AND to handle multiple conditions
|
// Split by AND to handle multiple conditions
|
||||||
conditions := splitByAND(where)
|
conditions := splitByAND(where)
|
||||||
|
|
||||||
@ -146,22 +157,32 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strip parentheses from the condition before checking
|
||||||
|
condToCheck := stripOuterParentheses(cond)
|
||||||
|
|
||||||
// Skip trivial conditions that always evaluate to true
|
// Skip trivial conditions that always evaluate to true
|
||||||
if IsTrivialCondition(cond) {
|
if IsTrivialCondition(condToCheck) {
|
||||||
logger.Debug("Removing trivial condition: '%s'", cond)
|
logger.Debug("Removing trivial condition: '%s'", cond)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||||
// attempt to add it
|
// attempt to add it
|
||||||
if tableName != "" && !hasTablePrefix(cond) {
|
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
if !IsSQLExpression(strings.ToLower(cond)) {
|
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||||
// Extract the column name and prefix it
|
// Extract the column name and prefix it
|
||||||
columnName := ExtractColumnName(cond)
|
columnName := ExtractColumnName(condToCheck)
|
||||||
if columnName != "" {
|
if columnName != "" {
|
||||||
|
// Only prefix if this is a valid column in the model
|
||||||
|
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||||
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||||
|
// Replace in the original condition (without stripped parens)
|
||||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -182,6 +203,43 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripOuterParentheses removes matching outer parentheses from a string
|
||||||
|
// It handles nested parentheses correctly
|
||||||
|
func stripOuterParentheses(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||||
|
depth := 0
|
||||||
|
matched := false
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch s[i] {
|
||||||
|
case '(':
|
||||||
|
depth++
|
||||||
|
case ')':
|
||||||
|
depth--
|
||||||
|
if depth == 0 && i == len(s)-1 {
|
||||||
|
matched = true
|
||||||
|
} else if depth == 0 {
|
||||||
|
// Found a closing paren before the end, so outer parens don't match
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the outer parentheses and continue
|
||||||
|
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||||
func splitByAND(where string) []string {
|
func splitByAND(where string) []string {
|
||||||
@ -245,3 +303,38 @@ func IsSQLKeyword(word string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
|
||||||
|
// Returns a map of column names for fast lookup, or nil if the model is not found
|
||||||
|
func getValidColumnsForTable(tableName string) map[string]bool {
|
||||||
|
// Try to get the model from the registry
|
||||||
|
model, err := modelregistry.GetModelByName(tableName)
|
||||||
|
if err != nil {
|
||||||
|
// Model not found, return nil to indicate we should use fallback behavior
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get SQL columns from the model
|
||||||
|
columns := reflection.GetSQLModelColumns(model)
|
||||||
|
if len(columns) == 0 {
|
||||||
|
// No columns found, return nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map for fast lookup
|
||||||
|
columnMap := make(map[string]bool, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
columnMap[strings.ToLower(col)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidColumn checks if a column name exists in the valid columns map
|
||||||
|
// Handles case-insensitive comparison
|
||||||
|
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||||
|
if validColumns == nil {
|
||||||
|
return true // No model info, assume valid
|
||||||
|
}
|
||||||
|
return validColumns[strings.ToLower(columnName)]
|
||||||
|
}
|
||||||
|
|||||||
224
pkg/common/sql_helpers_test.go
Normal file
224
pkg/common/sql_helpers_test.go
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeWhereClause(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trivial conditions in parentheses",
|
||||||
|
where: "(true AND true AND true)",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trivial conditions without parentheses",
|
||||||
|
where: "true AND true AND true",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single trivial condition",
|
||||||
|
where: "true",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid condition with parentheses",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed trivial and valid conditions",
|
||||||
|
where: "true AND status = 'active' AND 1=1",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "condition already with table prefix",
|
||||||
|
where: "users.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid conditions",
|
||||||
|
where: "status = 'active' AND age > 18",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no table name provided",
|
||||||
|
where: "status = 'active'",
|
||||||
|
tableName: "",
|
||||||
|
expected: "status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty where clause",
|
||||||
|
where: "",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripOuterParentheses(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single level parentheses",
|
||||||
|
input: "(true)",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple levels",
|
||||||
|
input: "((true))",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no parentheses",
|
||||||
|
input: "true",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mismatched parentheses",
|
||||||
|
input: "(true",
|
||||||
|
expected: "(true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex expression",
|
||||||
|
input: "(a AND b)",
|
||||||
|
expected: "a AND b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested but not outer",
|
||||||
|
input: "(a AND (b OR c)) AND d",
|
||||||
|
expected: "(a AND (b OR c)) AND d",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with spaces",
|
||||||
|
input: " ( true ) ",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := stripOuterParentheses(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTrivialCondition(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"true", "true", true},
|
||||||
|
{"true with spaces", " true ", true},
|
||||||
|
{"TRUE uppercase", "TRUE", true},
|
||||||
|
{"1=1", "1=1", true},
|
||||||
|
{"1 = 1", "1 = 1", true},
|
||||||
|
{"true = true", "true = true", true},
|
||||||
|
{"valid condition", "status = 'active'", false},
|
||||||
|
{"false", "false", false},
|
||||||
|
{"column name", "is_active", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsTrivialCondition(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model for model-aware sanitization tests
|
||||||
|
type MasterTask struct {
|
||||||
|
ID int `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
Status string `bun:"status"`
|
||||||
|
UserID int `bun:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||||
|
// Register the test model
|
||||||
|
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||||
|
if err != nil {
|
||||||
|
// Model might already be registered, ignore error
|
||||||
|
t.Logf("Model registration returned: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid column gets prefixed",
|
||||||
|
where: "status = 'active'",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid columns get prefixed",
|
||||||
|
where: "status = 'active' AND user_id = 123",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column does not get prefixed",
|
||||||
|
where: "invalid_column = 'value'",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "invalid_column = 'value'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mix of valid and trivial conditions",
|
||||||
|
where: "true AND status = 'active' AND 1=1",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parentheses with valid column",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -653,6 +653,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Apply WHERE clause
|
// Apply WHERE clause
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||||
|
|||||||
@ -267,6 +267,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
h.resolveRelationNamesInOptions(&options, model)
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Always sort according to the primary key if no sorting is specified
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user