diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index a83b945..52ba33d 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -691,6 +691,11 @@ func (b *BunSelectQuery) Order(order string) common.SelectQuery { return b } +func (b *BunSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery { + b.query = b.query.OrderExpr(order, args...) + return b +} + func (b *BunSelectQuery) Limit(n int) common.SelectQuery { b.query = b.query.Limit(n) return b diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 4bf1cf6..9d3b3d9 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -386,6 +386,12 @@ func (g *GormSelectQuery) Order(order string) common.SelectQuery { return g } +func (g *GormSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery { + // GORM's Order can handle expressions directly + g.db = g.db.Order(gorm.Expr(order, args...)) + return g +} + func (g *GormSelectQuery) Limit(n int) common.SelectQuery { g.db = g.db.Limit(n) return g diff --git a/pkg/common/adapters/database/pgsql.go b/pkg/common/adapters/database/pgsql.go index 4b81204..f2486f3 100644 --- a/pkg/common/adapters/database/pgsql.go +++ b/pkg/common/adapters/database/pgsql.go @@ -281,6 +281,13 @@ func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery { return p } +func (p *PgSQLSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery { + // For PgSQL, expressions are passed directly without quoting + // If there are args, we would need to format them, but for now just append the expression + p.orderBy = append(p.orderBy, order) + return p +} + func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery { p.limit = n return p diff --git a/pkg/common/interfaces.go b/pkg/common/interfaces.go index 57dd78f..03a72a0 100644 --- a/pkg/common/interfaces.go +++ b/pkg/common/interfaces.go @@ -46,6 +46,7 @@ type SelectQuery interface { PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery Order(order string) SelectQuery + OrderExpr(order string, args ...interface{}) SelectQuery Limit(n int) SelectQuery Offset(n int) SelectQuery Group(group string) SelectQuery diff --git a/pkg/common/validation.go b/pkg/common/validation.go index c177471..a1ac064 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -237,6 +237,13 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp for _, sort := range options.Sort { if v.IsValidColumn(sort.Column) { validSorts = append(validSorts, sort) + } else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { + // Allow sort by expression/subquery, but validate for security + if IsSafeSortExpression(sort.Column) { + validSorts = append(validSorts, sort) + } else { + logger.Warn("Unsafe sort expression '%s' removed", sort.Column) + } } else { logger.Warn("Invalid column in sort '%s' removed", sort.Column) } @@ -262,6 +269,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp } filteredPreload.Filters = validPreloadFilters + // Filter preload sort columns + validPreloadSorts := make([]SortOption, 0, len(preload.Sort)) + for _, sort := range preload.Sort { + if v.IsValidColumn(sort.Column) { + validPreloadSorts = append(validPreloadSorts, sort) + } else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { + // Allow sort by expression/subquery, but validate for security + if IsSafeSortExpression(sort.Column) { + validPreloadSorts = append(validPreloadSorts, sort) + } else { + logger.Warn("Unsafe sort expression in preload '%s' removed: '%s'", preload.Relation, sort.Column) + } + } else { + logger.Warn("Invalid column in preload '%s' sort '%s' removed", preload.Relation, sort.Column) + } + } + filteredPreload.Sort = validPreloadSorts + validPreloads = append(validPreloads, filteredPreload) } filtered.Preload = validPreloads @@ -269,6 +294,56 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp return filtered } +// IsSafeSortExpression validates that a sort expression (enclosed in brackets) is safe +// and doesn't contain SQL injection attempts or dangerous commands +func IsSafeSortExpression(expr string) bool { + if expr == "" { + return false + } + + // Expression must be enclosed in brackets + expr = strings.TrimSpace(expr) + if !strings.HasPrefix(expr, "(") || !strings.HasSuffix(expr, ")") { + return false + } + + // Remove outer brackets for content validation + expr = expr[1 : len(expr)-1] + expr = strings.TrimSpace(expr) + + // Convert to lowercase for checking dangerous keywords + exprLower := strings.ToLower(expr) + + // Check for dangerous SQL commands that should never be in a sort expression + dangerousKeywords := []string{ + "drop ", "delete ", "insert ", "update ", "alter ", "create ", + "truncate ", "exec ", "execute ", "grant ", "revoke ", + "into ", "values ", "set ", "shutdown", "xp_", + } + + for _, keyword := range dangerousKeywords { + if strings.Contains(exprLower, keyword) { + logger.Warn("Dangerous SQL keyword '%s' detected in sort expression: %s", keyword, expr) + return false + } + } + + // Check for SQL comment attempts + if strings.Contains(expr, "--") || strings.Contains(expr, "/*") || strings.Contains(expr, "*/") { + logger.Warn("SQL comment detected in sort expression: %s", expr) + return false + } + + // Check for semicolon (command separator) + if strings.Contains(expr, ";") { + logger.Warn("Command separator (;) detected in sort expression: %s", expr) + return false + } + + // Expression appears safe + return true +} + // GetValidColumns returns a list of all valid column names for debugging purposes func (v *ColumnValidator) GetValidColumns() []string { columns := make([]string, 0, len(v.validColumns)) diff --git a/pkg/common/validation_test.go b/pkg/common/validation_test.go index a68be98..1e56070 100644 --- a/pkg/common/validation_test.go +++ b/pkg/common/validation_test.go @@ -361,3 +361,83 @@ func TestFilterRequestOptions(t *testing.T) { t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column) } } + +func TestIsSafeSortExpression(t *testing.T) { + tests := []struct { + name string + expression string + shouldPass bool + }{ + // Safe expressions + {"Valid subquery", "(SELECT MAX(price) FROM products)", true}, + {"Valid CASE expression", "(CASE WHEN status = 'active' THEN 1 ELSE 0 END)", true}, + {"Valid aggregate", "(COUNT(*) OVER (PARTITION BY category))", true}, + {"Valid function", "(COALESCE(discount, 0))", true}, + + // Dangerous expressions - SQL injection attempts + {"DROP TABLE attempt", "(id); DROP TABLE users; --", false}, + {"DELETE attempt", "(id WHERE 1=1); DELETE FROM users; --", false}, + {"INSERT attempt", "(id); INSERT INTO admin VALUES ('hacker'); --", false}, + {"UPDATE attempt", "(id); UPDATE users SET role='admin'; --", false}, + {"EXEC attempt", "(id); EXEC sp_executesql 'DROP TABLE users'; --", false}, + {"XP_ stored proc", "(id); xp_cmdshell 'dir'; --", false}, + + // Comment injection + {"SQL comment dash", "(id) -- malicious comment", false}, + {"SQL comment block start", "(id) /* comment", false}, + {"SQL comment block end", "(id) comment */", false}, + + // Semicolon attempts + {"Semicolon separator", "(id); SELECT * FROM passwords", false}, + + // Empty/invalid + {"Empty string", "", false}, + {"Just brackets", "()", true}, // Empty but technically valid structure + {"No brackets", "id", false}, // Must have brackets for expressions + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsSafeSortExpression(tt.expression) + if result != tt.shouldPass { + t.Errorf("IsSafeSortExpression(%q) = %v, want %v", tt.expression, result, tt.shouldPass) + } + }) + } +} + +func TestFilterRequestOptions_WithSortExpressions(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + options := RequestOptions{ + Sort: []SortOption{ + {Column: "id", Direction: "ASC"}, // Valid column + {Column: "(SELECT MAX(age) FROM users)", Direction: "DESC"}, // Safe expression + {Column: "name", Direction: "ASC"}, // Valid column + {Column: "(id); DROP TABLE users; --", Direction: "DESC"}, // Dangerous expression + {Column: "invalid_col", Direction: "ASC"}, // Invalid column + {Column: "(CASE WHEN age > 18 THEN 1 ELSE 0 END)", Direction: "ASC"}, // Safe expression + }, + } + + filtered := validator.FilterRequestOptions(options) + + // Should keep: id, safe expression, name, another safe expression + // Should remove: dangerous expression, invalid column + expectedCount := 4 + if len(filtered.Sort) != expectedCount { + t.Errorf("Expected %d sort options, got %d", expectedCount, len(filtered.Sort)) + } + + // Verify the kept options + if filtered.Sort[0].Column != "id" { + t.Errorf("Expected first sort to be 'id', got '%s'", filtered.Sort[0].Column) + } + if filtered.Sort[1].Column != "(SELECT MAX(age) FROM users)" { + t.Errorf("Expected second sort to be safe expression, got '%s'", filtered.Sort[1].Column) + } + if filtered.Sort[2].Column != "name" { + t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column) + } +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index ce6bb64..6c4ffb8 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -513,7 +513,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st direction = "DESC" } logger.Debug("Applying sort: %s %s", sort.Column, direction) - query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction)) + + // Check if it's an expression (enclosed in brackets) - use directly without quoting + if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { + // For expressions, pass as raw SQL to prevent auto-quoting + query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction)) + } else { + // Regular column - let Bun handle quoting + query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction)) + } } // Get total count before pagination (unless skip count is requested) @@ -827,7 +835,14 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply sorting if len(preload.Sort) > 0 { for _, sort := range preload.Sort { - sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) + // Check if it's an expression (enclosed in brackets) - use directly without quoting + if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { + // For expressions, pass as raw SQL to prevent auto-quoting + sq = sq.OrderExpr(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) + } else { + // Regular column - let ORM handle quoting + sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) + } } } @@ -2188,7 +2203,14 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s if strings.EqualFold(sort.Direction, "desc") { direction = "DESC" } - sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction)) + + // Check if it's an expression (enclosed in brackets) - use directly without table prefix + if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { + sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction)) + } else { + // Regular column - add table prefix + sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction)) + } } sortSQL = strings.Join(sortParts, ", ") } else { @@ -2397,6 +2419,55 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio expandValidator := common.NewColumnValidator(relInfo.relatedModel) // Filter columns using the related model's validator filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns) + + // Filter sort columns in the expand Sort string + if expand.Sort != "" { + sortFields := strings.Split(expand.Sort, ",") + validSortFields := make([]string, 0, len(sortFields)) + for _, sortField := range sortFields { + sortField = strings.TrimSpace(sortField) + if sortField == "" { + continue + } + + // Extract column name (remove direction prefixes/suffixes) + colName := sortField + direction := "" + + if strings.HasPrefix(sortField, "-") { + direction = "-" + colName = strings.TrimPrefix(sortField, "-") + } else if strings.HasPrefix(sortField, "+") { + direction = "+" + colName = strings.TrimPrefix(sortField, "+") + } + + if strings.HasSuffix(strings.ToLower(colName), " desc") { + direction = " desc" + colName = strings.TrimSuffix(strings.ToLower(colName), " desc") + } else if strings.HasSuffix(strings.ToLower(colName), " asc") { + direction = " asc" + colName = strings.TrimSuffix(strings.ToLower(colName), " asc") + } + + colName = strings.TrimSpace(colName) + + // Validate the column name + if expandValidator.IsValidColumn(colName) { + validSortFields = append(validSortFields, direction+colName) + } else if strings.HasPrefix(colName, "(") && strings.HasSuffix(colName, ")") { + // Allow sort by expression/subquery, but validate for security + if common.IsSafeSortExpression(colName) { + validSortFields = append(validSortFields, direction+colName) + } else { + logger.Warn("Unsafe sort expression in expand '%s' removed: '%s'", expand.Relation, colName) + } + } else { + logger.Warn("Invalid column in expand '%s' sort '%s' removed", expand.Relation, colName) + } + } + filteredExpand.Sort = strings.Join(validSortFields, ",") + } } else { // If we can't find the relationship, log a warning and skip column filtering logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation) diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 903e50f..a64cdb4 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -529,19 +529,47 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { } // parseCommaSeparated parses comma-separated values and trims whitespace +// It respects bracket nesting and only splits on commas outside of parentheses func (h *Handler) parseCommaSeparated(value string) []string { if value == "" { return nil } - parts := strings.Split(value, ",") - result := make([]string, 0, len(parts)) - for _, part := range parts { - part = strings.TrimSpace(part) - if part != "" { - result = append(result, part) + result := make([]string, 0) + var current strings.Builder + nestingLevel := 0 + + for _, char := range value { + switch char { + case '(': + nestingLevel++ + current.WriteRune(char) + case ')': + nestingLevel-- + current.WriteRune(char) + case ',': + if nestingLevel == 0 { + // We're outside all brackets, so split here + part := strings.TrimSpace(current.String()) + if part != "" { + result = append(result, part) + } + current.Reset() + } else { + // Inside brackets, keep the comma + current.WriteRune(char) + } + default: + current.WriteRune(char) } } + + // Add the last part + part := strings.TrimSpace(current.String()) + if part != "" { + result = append(result, part) + } + return result }