Fixed order by. Added OrderExpr to database interface
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run

This commit is contained in:
Hein 2025-12-17 16:50:33 +02:00
parent 932f12ab0a
commit 9351093e2a
8 changed files with 282 additions and 9 deletions

View File

@ -691,6 +691,11 @@ func (b *BunSelectQuery) Order(order string) common.SelectQuery {
return b return b
} }
func (b *BunSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
b.query = b.query.OrderExpr(order, args...)
return b
}
func (b *BunSelectQuery) Limit(n int) common.SelectQuery { func (b *BunSelectQuery) Limit(n int) common.SelectQuery {
b.query = b.query.Limit(n) b.query = b.query.Limit(n)
return b return b

View File

@ -386,6 +386,12 @@ func (g *GormSelectQuery) Order(order string) common.SelectQuery {
return g return g
} }
func (g *GormSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
// GORM's Order can handle expressions directly
g.db = g.db.Order(gorm.Expr(order, args...))
return g
}
func (g *GormSelectQuery) Limit(n int) common.SelectQuery { func (g *GormSelectQuery) Limit(n int) common.SelectQuery {
g.db = g.db.Limit(n) g.db = g.db.Limit(n)
return g return g

View File

@ -281,6 +281,13 @@ func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery {
return p return p
} }
func (p *PgSQLSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
// For PgSQL, expressions are passed directly without quoting
// If there are args, we would need to format them, but for now just append the expression
p.orderBy = append(p.orderBy, order)
return p
}
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery { func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
p.limit = n p.limit = n
return p return p

View File

@ -46,6 +46,7 @@ type SelectQuery interface {
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
Order(order string) SelectQuery Order(order string) SelectQuery
OrderExpr(order string, args ...interface{}) SelectQuery
Limit(n int) SelectQuery Limit(n int) SelectQuery
Offset(n int) SelectQuery Offset(n int) SelectQuery
Group(group string) SelectQuery Group(group string) SelectQuery

View File

@ -237,6 +237,13 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
for _, sort := range options.Sort { for _, sort := range options.Sort {
if v.IsValidColumn(sort.Column) { if v.IsValidColumn(sort.Column) {
validSorts = append(validSorts, sort) validSorts = append(validSorts, sort)
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security
if IsSafeSortExpression(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
}
} else { } else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column) logger.Warn("Invalid column in sort '%s' removed", sort.Column)
} }
@ -262,6 +269,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
} }
filteredPreload.Filters = validPreloadFilters filteredPreload.Filters = validPreloadFilters
// Filter preload sort columns
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
for _, sort := range preload.Sort {
if v.IsValidColumn(sort.Column) {
validPreloadSorts = append(validPreloadSorts, sort)
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security
if IsSafeSortExpression(sort.Column) {
validPreloadSorts = append(validPreloadSorts, sort)
} else {
logger.Warn("Unsafe sort expression in preload '%s' removed: '%s'", preload.Relation, sort.Column)
}
} else {
logger.Warn("Invalid column in preload '%s' sort '%s' removed", preload.Relation, sort.Column)
}
}
filteredPreload.Sort = validPreloadSorts
validPreloads = append(validPreloads, filteredPreload) validPreloads = append(validPreloads, filteredPreload)
} }
filtered.Preload = validPreloads filtered.Preload = validPreloads
@ -269,6 +294,56 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
return filtered return filtered
} }
// IsSafeSortExpression validates that a sort expression (enclosed in brackets) is safe
// and doesn't contain SQL injection attempts or dangerous commands
func IsSafeSortExpression(expr string) bool {
if expr == "" {
return false
}
// Expression must be enclosed in brackets
expr = strings.TrimSpace(expr)
if !strings.HasPrefix(expr, "(") || !strings.HasSuffix(expr, ")") {
return false
}
// Remove outer brackets for content validation
expr = expr[1 : len(expr)-1]
expr = strings.TrimSpace(expr)
// Convert to lowercase for checking dangerous keywords
exprLower := strings.ToLower(expr)
// Check for dangerous SQL commands that should never be in a sort expression
dangerousKeywords := []string{
"drop ", "delete ", "insert ", "update ", "alter ", "create ",
"truncate ", "exec ", "execute ", "grant ", "revoke ",
"into ", "values ", "set ", "shutdown", "xp_",
}
for _, keyword := range dangerousKeywords {
if strings.Contains(exprLower, keyword) {
logger.Warn("Dangerous SQL keyword '%s' detected in sort expression: %s", keyword, expr)
return false
}
}
// Check for SQL comment attempts
if strings.Contains(expr, "--") || strings.Contains(expr, "/*") || strings.Contains(expr, "*/") {
logger.Warn("SQL comment detected in sort expression: %s", expr)
return false
}
// Check for semicolon (command separator)
if strings.Contains(expr, ";") {
logger.Warn("Command separator (;) detected in sort expression: %s", expr)
return false
}
// Expression appears safe
return true
}
// GetValidColumns returns a list of all valid column names for debugging purposes // GetValidColumns returns a list of all valid column names for debugging purposes
func (v *ColumnValidator) GetValidColumns() []string { func (v *ColumnValidator) GetValidColumns() []string {
columns := make([]string, 0, len(v.validColumns)) columns := make([]string, 0, len(v.validColumns))

View File

@ -361,3 +361,83 @@ func TestFilterRequestOptions(t *testing.T) {
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column) t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
} }
} }
func TestIsSafeSortExpression(t *testing.T) {
tests := []struct {
name string
expression string
shouldPass bool
}{
// Safe expressions
{"Valid subquery", "(SELECT MAX(price) FROM products)", true},
{"Valid CASE expression", "(CASE WHEN status = 'active' THEN 1 ELSE 0 END)", true},
{"Valid aggregate", "(COUNT(*) OVER (PARTITION BY category))", true},
{"Valid function", "(COALESCE(discount, 0))", true},
// Dangerous expressions - SQL injection attempts
{"DROP TABLE attempt", "(id); DROP TABLE users; --", false},
{"DELETE attempt", "(id WHERE 1=1); DELETE FROM users; --", false},
{"INSERT attempt", "(id); INSERT INTO admin VALUES ('hacker'); --", false},
{"UPDATE attempt", "(id); UPDATE users SET role='admin'; --", false},
{"EXEC attempt", "(id); EXEC sp_executesql 'DROP TABLE users'; --", false},
{"XP_ stored proc", "(id); xp_cmdshell 'dir'; --", false},
// Comment injection
{"SQL comment dash", "(id) -- malicious comment", false},
{"SQL comment block start", "(id) /* comment", false},
{"SQL comment block end", "(id) comment */", false},
// Semicolon attempts
{"Semicolon separator", "(id); SELECT * FROM passwords", false},
// Empty/invalid
{"Empty string", "", false},
{"Just brackets", "()", true}, // Empty but technically valid structure
{"No brackets", "id", false}, // Must have brackets for expressions
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsSafeSortExpression(tt.expression)
if result != tt.shouldPass {
t.Errorf("IsSafeSortExpression(%q) = %v, want %v", tt.expression, result, tt.shouldPass)
}
})
}
}
func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
options := RequestOptions{
Sort: []SortOption{
{Column: "id", Direction: "ASC"}, // Valid column
{Column: "(SELECT MAX(age) FROM users)", Direction: "DESC"}, // Safe expression
{Column: "name", Direction: "ASC"}, // Valid column
{Column: "(id); DROP TABLE users; --", Direction: "DESC"}, // Dangerous expression
{Column: "invalid_col", Direction: "ASC"}, // Invalid column
{Column: "(CASE WHEN age > 18 THEN 1 ELSE 0 END)", Direction: "ASC"}, // Safe expression
},
}
filtered := validator.FilterRequestOptions(options)
// Should keep: id, safe expression, name, another safe expression
// Should remove: dangerous expression, invalid column
expectedCount := 4
if len(filtered.Sort) != expectedCount {
t.Errorf("Expected %d sort options, got %d", expectedCount, len(filtered.Sort))
}
// Verify the kept options
if filtered.Sort[0].Column != "id" {
t.Errorf("Expected first sort to be 'id', got '%s'", filtered.Sort[0].Column)
}
if filtered.Sort[1].Column != "(SELECT MAX(age) FROM users)" {
t.Errorf("Expected second sort to be safe expression, got '%s'", filtered.Sort[1].Column)
}
if filtered.Sort[2].Column != "name" {
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
}
}

View File

@ -513,7 +513,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
direction = "DESC" direction = "DESC"
} }
logger.Debug("Applying sort: %s %s", sort.Column, direction) logger.Debug("Applying sort: %s %s", sort.Column, direction)
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
// Check if it's an expression (enclosed in brackets) - use directly without quoting
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// For expressions, pass as raw SQL to prevent auto-quoting
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
} else {
// Regular column - let Bun handle quoting
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
} }
// Get total count before pagination (unless skip count is requested) // Get total count before pagination (unless skip count is requested)
@ -827,7 +835,14 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply sorting // Apply sorting
if len(preload.Sort) > 0 { if len(preload.Sort) > 0 {
for _, sort := range preload.Sort { for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) // Check if it's an expression (enclosed in brackets) - use directly without quoting
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// For expressions, pass as raw SQL to prevent auto-quoting
sq = sq.OrderExpr(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
} else {
// Regular column - let ORM handle quoting
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
} }
} }
@ -2188,7 +2203,14 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
if strings.EqualFold(sort.Direction, "desc") { if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC" direction = "DESC"
} }
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
// Check if it's an expression (enclosed in brackets) - use directly without table prefix
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
} else {
// Regular column - add table prefix
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
}
} }
sortSQL = strings.Join(sortParts, ", ") sortSQL = strings.Join(sortParts, ", ")
} else { } else {
@ -2397,6 +2419,55 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
expandValidator := common.NewColumnValidator(relInfo.relatedModel) expandValidator := common.NewColumnValidator(relInfo.relatedModel)
// Filter columns using the related model's validator // Filter columns using the related model's validator
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns) filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
// Filter sort columns in the expand Sort string
if expand.Sort != "" {
sortFields := strings.Split(expand.Sort, ",")
validSortFields := make([]string, 0, len(sortFields))
for _, sortField := range sortFields {
sortField = strings.TrimSpace(sortField)
if sortField == "" {
continue
}
// Extract column name (remove direction prefixes/suffixes)
colName := sortField
direction := ""
if strings.HasPrefix(sortField, "-") {
direction = "-"
colName = strings.TrimPrefix(sortField, "-")
} else if strings.HasPrefix(sortField, "+") {
direction = "+"
colName = strings.TrimPrefix(sortField, "+")
}
if strings.HasSuffix(strings.ToLower(colName), " desc") {
direction = " desc"
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
direction = " asc"
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
}
colName = strings.TrimSpace(colName)
// Validate the column name
if expandValidator.IsValidColumn(colName) {
validSortFields = append(validSortFields, direction+colName)
} else if strings.HasPrefix(colName, "(") && strings.HasSuffix(colName, ")") {
// Allow sort by expression/subquery, but validate for security
if common.IsSafeSortExpression(colName) {
validSortFields = append(validSortFields, direction+colName)
} else {
logger.Warn("Unsafe sort expression in expand '%s' removed: '%s'", expand.Relation, colName)
}
} else {
logger.Warn("Invalid column in expand '%s' sort '%s' removed", expand.Relation, colName)
}
}
filteredExpand.Sort = strings.Join(validSortFields, ",")
}
} else { } else {
// If we can't find the relationship, log a warning and skip column filtering // If we can't find the relationship, log a warning and skip column filtering
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation) logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)

View File

@ -529,19 +529,47 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
} }
// parseCommaSeparated parses comma-separated values and trims whitespace // parseCommaSeparated parses comma-separated values and trims whitespace
// It respects bracket nesting and only splits on commas outside of parentheses
func (h *Handler) parseCommaSeparated(value string) []string { func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" { if value == "" {
return nil return nil
} }
parts := strings.Split(value, ",") result := make([]string, 0)
result := make([]string, 0, len(parts)) var current strings.Builder
for _, part := range parts { nestingLevel := 0
part = strings.TrimSpace(part)
if part != "" { for _, char := range value {
result = append(result, part) switch char {
case '(':
nestingLevel++
current.WriteRune(char)
case ')':
nestingLevel--
current.WriteRune(char)
case ',':
if nestingLevel == 0 {
// We're outside all brackets, so split here
part := strings.TrimSpace(current.String())
if part != "" {
result = append(result, part)
}
current.Reset()
} else {
// Inside brackets, keep the comma
current.WriteRune(char)
}
default:
current.WriteRune(char)
} }
} }
// Add the last part
part := strings.TrimSpace(current.String())
if part != "" {
result = append(result, part)
}
return result return result
} }