mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-18 19:00:36 +00:00
Fixed order by. Added OrderExpr to database interface
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
This commit is contained in:
parent
932f12ab0a
commit
9351093e2a
@ -691,6 +691,11 @@ func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.OrderExpr(order, args...)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Limit(n int) common.SelectQuery {
|
||||
b.query = b.query.Limit(n)
|
||||
return b
|
||||
|
||||
@ -386,6 +386,12 @@ func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
// GORM's Order can handle expressions directly
|
||||
g.db = g.db.Order(gorm.Expr(order, args...))
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Limit(n int) common.SelectQuery {
|
||||
g.db = g.db.Limit(n)
|
||||
return g
|
||||
|
||||
@ -281,6 +281,13 @@ func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
// For PgSQL, expressions are passed directly without quoting
|
||||
// If there are args, we would need to format them, but for now just append the expression
|
||||
p.orderBy = append(p.orderBy, order)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
|
||||
p.limit = n
|
||||
return p
|
||||
|
||||
@ -46,6 +46,7 @@ type SelectQuery interface {
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
OrderExpr(order string, args ...interface{}) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
Group(group string) SelectQuery
|
||||
|
||||
@ -237,6 +237,13 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
for _, sort := range options.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
}
|
||||
@ -262,6 +269,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
|
||||
// Filter preload sort columns
|
||||
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||
for _, sort := range preload.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression in preload '%s' removed: '%s'", preload.Relation, sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' sort '%s' removed", preload.Relation, sort.Column)
|
||||
}
|
||||
}
|
||||
filteredPreload.Sort = validPreloadSorts
|
||||
|
||||
validPreloads = append(validPreloads, filteredPreload)
|
||||
}
|
||||
filtered.Preload = validPreloads
|
||||
@ -269,6 +294,56 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
return filtered
|
||||
}
|
||||
|
||||
// IsSafeSortExpression validates that a sort expression (enclosed in brackets) is safe
|
||||
// and doesn't contain SQL injection attempts or dangerous commands
|
||||
func IsSafeSortExpression(expr string) bool {
|
||||
if expr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Expression must be enclosed in brackets
|
||||
expr = strings.TrimSpace(expr)
|
||||
if !strings.HasPrefix(expr, "(") || !strings.HasSuffix(expr, ")") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove outer brackets for content validation
|
||||
expr = expr[1 : len(expr)-1]
|
||||
expr = strings.TrimSpace(expr)
|
||||
|
||||
// Convert to lowercase for checking dangerous keywords
|
||||
exprLower := strings.ToLower(expr)
|
||||
|
||||
// Check for dangerous SQL commands that should never be in a sort expression
|
||||
dangerousKeywords := []string{
|
||||
"drop ", "delete ", "insert ", "update ", "alter ", "create ",
|
||||
"truncate ", "exec ", "execute ", "grant ", "revoke ",
|
||||
"into ", "values ", "set ", "shutdown", "xp_",
|
||||
}
|
||||
|
||||
for _, keyword := range dangerousKeywords {
|
||||
if strings.Contains(exprLower, keyword) {
|
||||
logger.Warn("Dangerous SQL keyword '%s' detected in sort expression: %s", keyword, expr)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for SQL comment attempts
|
||||
if strings.Contains(expr, "--") || strings.Contains(expr, "/*") || strings.Contains(expr, "*/") {
|
||||
logger.Warn("SQL comment detected in sort expression: %s", expr)
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for semicolon (command separator)
|
||||
if strings.Contains(expr, ";") {
|
||||
logger.Warn("Command separator (;) detected in sort expression: %s", expr)
|
||||
return false
|
||||
}
|
||||
|
||||
// Expression appears safe
|
||||
return true
|
||||
}
|
||||
|
||||
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||
func (v *ColumnValidator) GetValidColumns() []string {
|
||||
columns := make([]string, 0, len(v.validColumns))
|
||||
|
||||
@ -361,3 +361,83 @@ func TestFilterRequestOptions(t *testing.T) {
|
||||
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeSortExpression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
shouldPass bool
|
||||
}{
|
||||
// Safe expressions
|
||||
{"Valid subquery", "(SELECT MAX(price) FROM products)", true},
|
||||
{"Valid CASE expression", "(CASE WHEN status = 'active' THEN 1 ELSE 0 END)", true},
|
||||
{"Valid aggregate", "(COUNT(*) OVER (PARTITION BY category))", true},
|
||||
{"Valid function", "(COALESCE(discount, 0))", true},
|
||||
|
||||
// Dangerous expressions - SQL injection attempts
|
||||
{"DROP TABLE attempt", "(id); DROP TABLE users; --", false},
|
||||
{"DELETE attempt", "(id WHERE 1=1); DELETE FROM users; --", false},
|
||||
{"INSERT attempt", "(id); INSERT INTO admin VALUES ('hacker'); --", false},
|
||||
{"UPDATE attempt", "(id); UPDATE users SET role='admin'; --", false},
|
||||
{"EXEC attempt", "(id); EXEC sp_executesql 'DROP TABLE users'; --", false},
|
||||
{"XP_ stored proc", "(id); xp_cmdshell 'dir'; --", false},
|
||||
|
||||
// Comment injection
|
||||
{"SQL comment dash", "(id) -- malicious comment", false},
|
||||
{"SQL comment block start", "(id) /* comment", false},
|
||||
{"SQL comment block end", "(id) comment */", false},
|
||||
|
||||
// Semicolon attempts
|
||||
{"Semicolon separator", "(id); SELECT * FROM passwords", false},
|
||||
|
||||
// Empty/invalid
|
||||
{"Empty string", "", false},
|
||||
{"Just brackets", "()", true}, // Empty but technically valid structure
|
||||
{"No brackets", "id", false}, // Must have brackets for expressions
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsSafeSortExpression(tt.expression)
|
||||
if result != tt.shouldPass {
|
||||
t.Errorf("IsSafeSortExpression(%q) = %v, want %v", tt.expression, result, tt.shouldPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
options := RequestOptions{
|
||||
Sort: []SortOption{
|
||||
{Column: "id", Direction: "ASC"}, // Valid column
|
||||
{Column: "(SELECT MAX(age) FROM users)", Direction: "DESC"}, // Safe expression
|
||||
{Column: "name", Direction: "ASC"}, // Valid column
|
||||
{Column: "(id); DROP TABLE users; --", Direction: "DESC"}, // Dangerous expression
|
||||
{Column: "invalid_col", Direction: "ASC"}, // Invalid column
|
||||
{Column: "(CASE WHEN age > 18 THEN 1 ELSE 0 END)", Direction: "ASC"}, // Safe expression
|
||||
},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
// Should keep: id, safe expression, name, another safe expression
|
||||
// Should remove: dangerous expression, invalid column
|
||||
expectedCount := 4
|
||||
if len(filtered.Sort) != expectedCount {
|
||||
t.Errorf("Expected %d sort options, got %d", expectedCount, len(filtered.Sort))
|
||||
}
|
||||
|
||||
// Verify the kept options
|
||||
if filtered.Sort[0].Column != "id" {
|
||||
t.Errorf("Expected first sort to be 'id', got '%s'", filtered.Sort[0].Column)
|
||||
}
|
||||
if filtered.Sort[1].Column != "(SELECT MAX(age) FROM users)" {
|
||||
t.Errorf("Expected second sort to be safe expression, got '%s'", filtered.Sort[1].Column)
|
||||
}
|
||||
if filtered.Sort[2].Column != "name" {
|
||||
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||
}
|
||||
}
|
||||
|
||||
@ -513,7 +513,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
direction = "DESC"
|
||||
}
|
||||
logger.Debug("Applying sort: %s %s", sort.Column, direction)
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
|
||||
// Check if it's an expression (enclosed in brackets) - use directly without quoting
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
} else {
|
||||
// Regular column - let Bun handle quoting
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
}
|
||||
|
||||
// Get total count before pagination (unless skip count is requested)
|
||||
@ -827,7 +835,14 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
// Apply sorting
|
||||
if len(preload.Sort) > 0 {
|
||||
for _, sort := range preload.Sort {
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
// Check if it's an expression (enclosed in brackets) - use directly without quoting
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||
sq = sq.OrderExpr(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
} else {
|
||||
// Regular column - let ORM handle quoting
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2188,7 +2203,14 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||
|
||||
// Check if it's an expression (enclosed in brackets) - use directly without table prefix
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
} else {
|
||||
// Regular column - add table prefix
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||
}
|
||||
}
|
||||
sortSQL = strings.Join(sortParts, ", ")
|
||||
} else {
|
||||
@ -2397,6 +2419,55 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
||||
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||
// Filter columns using the related model's validator
|
||||
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||
|
||||
// Filter sort columns in the expand Sort string
|
||||
if expand.Sort != "" {
|
||||
sortFields := strings.Split(expand.Sort, ",")
|
||||
validSortFields := make([]string, 0, len(sortFields))
|
||||
for _, sortField := range sortFields {
|
||||
sortField = strings.TrimSpace(sortField)
|
||||
if sortField == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract column name (remove direction prefixes/suffixes)
|
||||
colName := sortField
|
||||
direction := ""
|
||||
|
||||
if strings.HasPrefix(sortField, "-") {
|
||||
direction = "-"
|
||||
colName = strings.TrimPrefix(sortField, "-")
|
||||
} else if strings.HasPrefix(sortField, "+") {
|
||||
direction = "+"
|
||||
colName = strings.TrimPrefix(sortField, "+")
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(colName), " desc") {
|
||||
direction = " desc"
|
||||
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
|
||||
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
|
||||
direction = " asc"
|
||||
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
|
||||
}
|
||||
|
||||
colName = strings.TrimSpace(colName)
|
||||
|
||||
// Validate the column name
|
||||
if expandValidator.IsValidColumn(colName) {
|
||||
validSortFields = append(validSortFields, direction+colName)
|
||||
} else if strings.HasPrefix(colName, "(") && strings.HasSuffix(colName, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if common.IsSafeSortExpression(colName) {
|
||||
validSortFields = append(validSortFields, direction+colName)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression in expand '%s' removed: '%s'", expand.Relation, colName)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in expand '%s' sort '%s' removed", expand.Relation, colName)
|
||||
}
|
||||
}
|
||||
filteredExpand.Sort = strings.Join(validSortFields, ",")
|
||||
}
|
||||
} else {
|
||||
// If we can't find the relationship, log a warning and skip column filtering
|
||||
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
||||
|
||||
@ -529,19 +529,47 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
||||
}
|
||||
|
||||
// parseCommaSeparated parses comma-separated values and trims whitespace
|
||||
// It respects bracket nesting and only splits on commas outside of parentheses
|
||||
func (h *Handler) parseCommaSeparated(value string) []string {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
result := make([]string, 0)
|
||||
var current strings.Builder
|
||||
nestingLevel := 0
|
||||
|
||||
for _, char := range value {
|
||||
switch char {
|
||||
case '(':
|
||||
nestingLevel++
|
||||
current.WriteRune(char)
|
||||
case ')':
|
||||
nestingLevel--
|
||||
current.WriteRune(char)
|
||||
case ',':
|
||||
if nestingLevel == 0 {
|
||||
// We're outside all brackets, so split here
|
||||
part := strings.TrimSpace(current.String())
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
current.Reset()
|
||||
} else {
|
||||
// Inside brackets, keep the comma
|
||||
current.WriteRune(char)
|
||||
}
|
||||
default:
|
||||
current.WriteRune(char)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last part
|
||||
part := strings.TrimSpace(current.String())
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user