From 99001c749de51e4b8462df642a415fd0f1b81787 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 10 Dec 2025 09:52:13 +0200 Subject: [PATCH] Better sql where validation --- pkg/common/sql_helpers.go | 98 +++++++++++++++++-- pkg/common/sql_helpers_test.go | 167 +++++++++++++++++++++++++++++++++ 2 files changed, 256 insertions(+), 9 deletions(-) diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 1d2ac6c..4c99001 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -1,6 +1,7 @@ package common import ( + "fmt" "strings" "github.com/bitechdev/ResolveSpec/pkg/logger" @@ -78,6 +79,41 @@ func IsTrivialCondition(cond string) bool { return false } +// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses +// Returns an error if any dangerous keywords are found +func validateWhereClauseSecurity(where string) error { + if where == "" { + return nil + } + + lowerWhere := strings.ToLower(where) + + // List of dangerous SQL keywords that should never appear in WHERE clauses + dangerousKeywords := []string{ + "delete ", "delete\t", "delete\n", "delete;", + "update ", "update\t", "update\n", "update;", + "truncate ", "truncate\t", "truncate\n", "truncate;", + "drop ", "drop\t", "drop\n", "drop;", + "alter ", "alter\t", "alter\n", "alter;", + "create ", "create\t", "create\n", "create;", + "insert ", "insert\t", "insert\n", "insert;", + "grant ", "grant\t", "grant\n", "grant;", + "revoke ", "revoke\t", "revoke\n", "revoke;", + "exec ", "exec\t", "exec\n", "exec;", + "execute ", "execute\t", "execute\n", "execute;", + ";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert", + } + + for _, keyword := range dangerousKeywords { + if strings.Contains(lowerWhere, keyword) { + logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword)) + return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword)) + } + } + + return nil +} + // SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes // This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL // @@ -100,6 +136,12 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti where = strings.TrimSpace(where) + // Validate that the WHERE clause doesn't contain dangerous SQL statements + if err := validateWhereClauseSecurity(where); err != nil { + logger.Debug("Security validation failed for WHERE clause: %v", err) + return "" + } + // Strip outer parentheses and re-trim where = stripOuterParentheses(where) @@ -221,19 +263,57 @@ func stripOuterParentheses(s string) string { } // splitByAND splits a WHERE clause by AND operators (case-insensitive) -// This is a simple split that doesn't handle nested parentheses or complex expressions +// This is parenthesis-aware and won't split on AND operators inside subqueries func splitByAND(where string) []string { - // First try uppercase AND - conditions := strings.Split(where, " AND ") + conditions := []string{} + currentCondition := strings.Builder{} + depth := 0 // Track parenthesis depth + i := 0 - // If we didn't split on uppercase, try lowercase - if len(conditions) == 1 { - conditions = strings.Split(where, " and ") + for i < len(where) { + ch := where[i] + + // Track parenthesis depth + if ch == '(' { + depth++ + currentCondition.WriteByte(ch) + i++ + continue + } else if ch == ')' { + depth-- + currentCondition.WriteByte(ch) + i++ + continue + } + + // Only look for AND operators at depth 0 (not inside parentheses) + if depth == 0 { + // Check if we're at an AND operator (case-insensitive) + // We need at least " AND " (5 chars) or " and " (5 chars) + if i+5 <= len(where) { + substring := where[i : i+5] + lowerSubstring := strings.ToLower(substring) + + if lowerSubstring == " and " { + // Found an AND operator at the top level + // Add the current condition to the list + conditions = append(conditions, currentCondition.String()) + currentCondition.Reset() + // Skip past the AND operator + i += 5 + continue + } + } + } + + // Not an AND operator or we're inside parentheses, just add the character + currentCondition.WriteByte(ch) + i++ } - // If we still didn't split, try mixed case - if len(conditions) == 1 { - conditions = strings.Split(where, " And ") + // Add the last condition + if currentCondition.Len() > 0 { + conditions = append(conditions, currentCondition.String()) } return conditions diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index 5f328a5..92c6d12 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -1,6 +1,7 @@ package common import ( + "strings" "testing" "github.com/bitechdev/ResolveSpec/pkg/modelregistry" @@ -85,6 +86,42 @@ func TestSanitizeWhereClause(t *testing.T) { tableName: "users", expected: "users.status = 'active' AND users.age > 18", }, + { + name: "mixed case AND operators", + where: "status = 'active' AND age > 18 and name = 'John'", + tableName: "users", + expected: "status = 'active' AND age > 18 AND name = 'John'", + }, + { + name: "subquery with ORDER BY and LIMIT - allowed", + where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)", + tableName: "users", + expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)", + }, + { + name: "dangerous DELETE keyword - blocked", + where: "status = 'active'; DELETE FROM users", + tableName: "users", + expected: "", + }, + { + name: "dangerous UPDATE keyword - blocked", + where: "1=1; UPDATE users SET admin = true", + tableName: "users", + expected: "", + }, + { + name: "dangerous TRUNCATE keyword - blocked", + where: "status = 'active' OR TRUNCATE TABLE users", + tableName: "users", + expected: "", + }, + { + name: "dangerous DROP keyword - blocked", + where: "status = 'active'; DROP TABLE users", + tableName: "users", + expected: "", + }, } for _, tt := range tests { @@ -138,6 +175,11 @@ func TestStripOuterParentheses(t *testing.T) { input: " ( true ) ", expected: "true", }, + { + name: "complex sub query", + input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)", + expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3", + }, } for _, tt := range tests { @@ -337,6 +379,131 @@ type MasterTask struct { UserID int `bun:"user_id"` } +func TestSplitByAND(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "uppercase AND", + input: "status = 'active' AND age > 18", + expected: []string{"status = 'active'", "age > 18"}, + }, + { + name: "lowercase and", + input: "status = 'active' and age > 18", + expected: []string{"status = 'active'", "age > 18"}, + }, + { + name: "mixed case AND", + input: "status = 'active' AND age > 18 and name = 'John'", + expected: []string{"status = 'active'", "age > 18", "name = 'John'"}, + }, + { + name: "single condition", + input: "status = 'active'", + expected: []string{"status = 'active'"}, + }, + { + name: "multiple uppercase AND", + input: "a = 1 AND b = 2 AND c = 3", + expected: []string{"a = 1", "b = 2", "c = 3"}, + }, + { + name: "multiple case subquery", + input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3", + expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitByAND(tt.input) + if len(result) != len(tt.expected) { + t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected)) + return + } + for i := range result { + if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) { + t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i]) + } + } + }) + } +} + +func TestValidateWhereClauseSecurity(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + { + name: "safe WHERE clause", + input: "status = 'active' AND age > 18", + expectError: false, + }, + { + name: "safe subquery", + input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)", + expectError: false, + }, + { + name: "DELETE keyword", + input: "status = 'active'; DELETE FROM users", + expectError: true, + }, + { + name: "UPDATE keyword", + input: "1=1; UPDATE users SET admin = true", + expectError: true, + }, + { + name: "TRUNCATE keyword", + input: "status = 'active' OR TRUNCATE TABLE users", + expectError: true, + }, + { + name: "DROP keyword", + input: "status = 'active'; DROP TABLE users", + expectError: true, + }, + { + name: "INSERT keyword", + input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')", + expectError: true, + }, + { + name: "ALTER keyword", + input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN", + expectError: true, + }, + { + name: "CREATE keyword", + input: "1=1; CREATE TABLE malicious (id INT)", + expectError: true, + }, + { + name: "empty clause", + input: "", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateWhereClauseSecurity(tt.input) + if tt.expectError && err == nil { + t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err) + } + }) + } +} + func TestSanitizeWhereClauseWithModel(t *testing.T) { // Register the test model err := modelregistry.RegisterModel(MasterTask{}, "mastertask")