diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 6730db6..0af6616 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -234,35 +234,52 @@ func stripOuterParentheses(s string) string { s = strings.TrimSpace(s) for { - if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' { + stripped, wasStripped := stripOneMatchingOuterParen(s) + if !wasStripped { return s } + s = stripped + } +} - // Check if these parentheses match (i.e., they're the outermost pair) - depth := 0 - matched := false - for i := 0; i < len(s); i++ { - switch s[i] { - case '(': - depth++ - case ')': - depth-- - if depth == 0 && i == len(s)-1 { - matched = true - } else if depth == 0 { - // Found a closing paren before the end, so outer parens don't match - return s - } +// stripOneOuterParentheses removes only one level of matching outer parentheses from a string +// Unlike stripOuterParentheses, this only strips once, preserving nested parentheses +func stripOneOuterParentheses(s string) string { + stripped, _ := stripOneMatchingOuterParen(strings.TrimSpace(s)) + return stripped +} + +// stripOneMatchingOuterParen is a helper that strips one matching pair of outer parentheses +// Returns the stripped string and a boolean indicating if stripping occurred +func stripOneMatchingOuterParen(s string) (string, bool) { + if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' { + return s, false + } + + // Check if these parentheses match (i.e., they're the outermost pair) + depth := 0 + matched := false + for i := 0; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 && i == len(s)-1 { + matched = true + } else if depth == 0 { + // Found a closing paren before the end, so outer parens don't match + return s, false } } - - if !matched { - return s - } - - // Strip the outer parentheses and continue - s = strings.TrimSpace(s[1 : len(s)-1]) } + + if !matched { + return s, false + } + + // Strip the outer parentheses + return strings.TrimSpace(s[1 : len(s)-1]), true } // splitByAND splits a WHERE clause by AND operators (case-insensitive) @@ -683,8 +700,8 @@ func AddTablePrefixToColumns(where string, tableName string) string { // - No valid column reference is found // - The column doesn't exist in the table (when validColumns is provided) func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string { - // Strip outer grouping parentheses to get to the actual condition - strippedCond := stripOuterParentheses(cond) + // Strip one level of outer grouping parentheses to get to the actual condition + strippedCond := stripOneOuterParentheses(cond) // Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.) if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) { @@ -692,6 +709,34 @@ func addPrefixToSingleCondition(cond string, tableName string, validColumns map[ return cond } + // After stripping outer parentheses, check if there are multiple AND-separated conditions + // at the top level. If so, split and process each separately to avoid incorrectly + // treating "true AND status" as a single column name. + subConditions := splitByAND(strippedCond) + if len(subConditions) > 1 { + // Multiple conditions found - process each separately + logger.Debug("Found %d sub-conditions after stripping parentheses, processing separately", len(subConditions)) + processedConditions := make([]string, 0, len(subConditions)) + for _, subCond := range subConditions { + // Recursively process each sub-condition + processed := addPrefixToSingleCondition(subCond, tableName, validColumns) + processedConditions = append(processedConditions, processed) + } + result := strings.Join(processedConditions, " AND ") + // Preserve original outer parentheses if they existed + if cond != strippedCond { + result = "(" + result + ")" + } + return result + } + + // If we stripped parentheses and still have more parentheses, recursively process + if cond != strippedCond && strings.HasPrefix(strippedCond, "(") && strings.HasSuffix(strippedCond, ")") { + // Recursively handle nested parentheses + processed := addPrefixToSingleCondition(strippedCond, tableName, validColumns) + return "(" + processed + ")" + } + // Extract the left side of the comparison (before the operator) columnRef := extractLeftSideOfComparison(strippedCond) if columnRef == "" { diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index d4a0706..6f2a4ca 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -658,3 +658,76 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) { }) } } + +func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) { +tests := []struct { +name string +where string +tableName string +expected string +}{ +{ +name: "Parentheses with true AND condition - should not prefix true", +where: "(true AND status = 'active')", +tableName: "mastertask", +expected: "(true AND mastertask.status = 'active')", +}, +{ +name: "Parentheses with multiple conditions including true", +where: "(true AND status = 'active' AND id > 5)", +tableName: "mastertask", +expected: "(true AND mastertask.status = 'active' AND mastertask.id > 5)", +}, +{ +name: "Nested parentheses with true", +where: "((true AND status = 'active'))", +tableName: "mastertask", +expected: "((true AND mastertask.status = 'active'))", +}, +{ +name: "Mixed: false AND valid conditions", +where: "(false AND name = 'test')", +tableName: "mastertask", +expected: "(false AND mastertask.name = 'test')", +}, +{ +name: "Mixed: null AND valid conditions", +where: "(null AND status = 'active')", +tableName: "mastertask", +expected: "(null AND mastertask.status = 'active')", +}, +{ +name: "Multiple true conditions in parentheses", +where: "(true AND true AND status = 'active')", +tableName: "mastertask", +expected: "(true AND true AND mastertask.status = 'active')", +}, +{ +name: "Simple true without parens - should not prefix", +where: "true", +tableName: "mastertask", +expected: "true", +}, +{ +name: "Simple condition without parens - should prefix", +where: "status = 'active'", +tableName: "mastertask", +expected: "mastertask.status = 'active'", +}, +{ +name: "Unregistered table with true - should not prefix true", +where: "(true AND status = 'active')", +tableName: "unregistered_table", +expected: "(true AND unregistered_table.status = 'active')", +}, +} + +for _, tt := range tests { +t.Run(tt.name, func(t *testing.T) { +result := AddTablePrefixToColumns(tt.where, tt.tableName) +if result != tt.expected { +t.Errorf("AddTablePrefixToColumns(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected) +} +}) +} +}