mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
Better sql where validation
This commit is contained in:
parent
1f7a57f8e3
commit
99001c749d
@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@ -78,6 +79,41 @@ func IsTrivialCondition(cond string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
|
||||
// Returns an error if any dangerous keywords are found
|
||||
func validateWhereClauseSecurity(where string) error {
|
||||
if where == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
lowerWhere := strings.ToLower(where)
|
||||
|
||||
// List of dangerous SQL keywords that should never appear in WHERE clauses
|
||||
dangerousKeywords := []string{
|
||||
"delete ", "delete\t", "delete\n", "delete;",
|
||||
"update ", "update\t", "update\n", "update;",
|
||||
"truncate ", "truncate\t", "truncate\n", "truncate;",
|
||||
"drop ", "drop\t", "drop\n", "drop;",
|
||||
"alter ", "alter\t", "alter\n", "alter;",
|
||||
"create ", "create\t", "create\n", "create;",
|
||||
"insert ", "insert\t", "insert\n", "insert;",
|
||||
"grant ", "grant\t", "grant\n", "grant;",
|
||||
"revoke ", "revoke\t", "revoke\n", "revoke;",
|
||||
"exec ", "exec\t", "exec\n", "exec;",
|
||||
"execute ", "execute\t", "execute\n", "execute;",
|
||||
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
|
||||
}
|
||||
|
||||
for _, keyword := range dangerousKeywords {
|
||||
if strings.Contains(lowerWhere, keyword) {
|
||||
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
|
||||
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||
//
|
||||
@ -100,6 +136,12 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Validate that the WHERE clause doesn't contain dangerous SQL statements
|
||||
if err := validateWhereClauseSecurity(where); err != nil {
|
||||
logger.Debug("Security validation failed for WHERE clause: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip outer parentheses and re-trim
|
||||
where = stripOuterParentheses(where)
|
||||
|
||||
@ -221,19 +263,57 @@ func stripOuterParentheses(s string) string {
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||
func splitByAND(where string) []string {
|
||||
// First try uppercase AND
|
||||
conditions := strings.Split(where, " AND ")
|
||||
conditions := []string{}
|
||||
currentCondition := strings.Builder{}
|
||||
depth := 0 // Track parenthesis depth
|
||||
i := 0
|
||||
|
||||
// If we didn't split on uppercase, try lowercase
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " and ")
|
||||
for i < len(where) {
|
||||
ch := where[i]
|
||||
|
||||
// Track parenthesis depth
|
||||
if ch == '(' {
|
||||
depth++
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
} else if ch == ')' {
|
||||
depth--
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// If we still didn't split, try mixed case
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " And ")
|
||||
// Only look for AND operators at depth 0 (not inside parentheses)
|
||||
if depth == 0 {
|
||||
// Check if we're at an AND operator (case-insensitive)
|
||||
// We need at least " AND " (5 chars) or " and " (5 chars)
|
||||
if i+5 <= len(where) {
|
||||
substring := where[i : i+5]
|
||||
lowerSubstring := strings.ToLower(substring)
|
||||
|
||||
if lowerSubstring == " and " {
|
||||
// Found an AND operator at the top level
|
||||
// Add the current condition to the list
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
currentCondition.Reset()
|
||||
// Skip past the AND operator
|
||||
i += 5
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not an AND operator or we're inside parentheses, just add the character
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
// Add the last condition
|
||||
if currentCondition.Len() > 0 {
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
}
|
||||
|
||||
return conditions
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
@ -85,6 +86,42 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "mixed case AND operators",
|
||||
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||
tableName: "users",
|
||||
expected: "status = 'active' AND age > 18 AND name = 'John'",
|
||||
},
|
||||
{
|
||||
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
tableName: "users",
|
||||
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
},
|
||||
{
|
||||
name: "dangerous DELETE keyword - blocked",
|
||||
where: "status = 'active'; DELETE FROM users",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "dangerous UPDATE keyword - blocked",
|
||||
where: "1=1; UPDATE users SET admin = true",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "dangerous TRUNCATE keyword - blocked",
|
||||
where: "status = 'active' OR TRUNCATE TABLE users",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "dangerous DROP keyword - blocked",
|
||||
where: "status = 'active'; DROP TABLE users",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -138,6 +175,11 @@ func TestStripOuterParentheses(t *testing.T) {
|
||||
input: " ( true ) ",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "complex sub query",
|
||||
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
|
||||
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -337,6 +379,131 @@ type MasterTask struct {
|
||||
UserID int `bun:"user_id"`
|
||||
}
|
||||
|
||||
func TestSplitByAND(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "uppercase AND",
|
||||
input: "status = 'active' AND age > 18",
|
||||
expected: []string{"status = 'active'", "age > 18"},
|
||||
},
|
||||
{
|
||||
name: "lowercase and",
|
||||
input: "status = 'active' and age > 18",
|
||||
expected: []string{"status = 'active'", "age > 18"},
|
||||
},
|
||||
{
|
||||
name: "mixed case AND",
|
||||
input: "status = 'active' AND age > 18 and name = 'John'",
|
||||
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
|
||||
},
|
||||
{
|
||||
name: "single condition",
|
||||
input: "status = 'active'",
|
||||
expected: []string{"status = 'active'"},
|
||||
},
|
||||
{
|
||||
name: "multiple uppercase AND",
|
||||
input: "a = 1 AND b = 2 AND c = 3",
|
||||
expected: []string{"a = 1", "b = 2", "c = 3"},
|
||||
},
|
||||
{
|
||||
name: "multiple case subquery",
|
||||
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitByAND(tt.input)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
|
||||
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWhereClauseSecurity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "safe WHERE clause",
|
||||
input: "status = 'active' AND age > 18",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "safe subquery",
|
||||
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "DELETE keyword",
|
||||
input: "status = 'active'; DELETE FROM users",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "UPDATE keyword",
|
||||
input: "1=1; UPDATE users SET admin = true",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "TRUNCATE keyword",
|
||||
input: "status = 'active' OR TRUNCATE TABLE users",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "DROP keyword",
|
||||
input: "status = 'active'; DROP TABLE users",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "INSERT keyword",
|
||||
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "ALTER keyword",
|
||||
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "CREATE keyword",
|
||||
input: "1=1; CREATE TABLE malicious (id INT)",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty clause",
|
||||
input: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateWhereClauseSecurity(tt.input)
|
||||
if tt.expectError && err == nil {
|
||||
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
// Register the test model
|
||||
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user