Fixed order by. Added OrderExpr to database interface
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run

This commit is contained in:
Hein
2025-12-17 16:50:33 +02:00
parent 932f12ab0a
commit 9351093e2a
8 changed files with 282 additions and 9 deletions

View File

@@ -513,7 +513,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
direction = "DESC"
}
logger.Debug("Applying sort: %s %s", sort.Column, direction)
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
// Check if it's an expression (enclosed in brackets) - use directly without quoting
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// For expressions, pass as raw SQL to prevent auto-quoting
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
} else {
// Regular column - let Bun handle quoting
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
}
// Get total count before pagination (unless skip count is requested)
@@ -827,7 +835,14 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply sorting
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
// Check if it's an expression (enclosed in brackets) - use directly without quoting
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// For expressions, pass as raw SQL to prevent auto-quoting
sq = sq.OrderExpr(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
} else {
// Regular column - let ORM handle quoting
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
}
@@ -2188,7 +2203,14 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
// Check if it's an expression (enclosed in brackets) - use directly without table prefix
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
} else {
// Regular column - add table prefix
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
}
}
sortSQL = strings.Join(sortParts, ", ")
} else {
@@ -2397,6 +2419,55 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
// Filter columns using the related model's validator
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
// Filter sort columns in the expand Sort string
if expand.Sort != "" {
sortFields := strings.Split(expand.Sort, ",")
validSortFields := make([]string, 0, len(sortFields))
for _, sortField := range sortFields {
sortField = strings.TrimSpace(sortField)
if sortField == "" {
continue
}
// Extract column name (remove direction prefixes/suffixes)
colName := sortField
direction := ""
if strings.HasPrefix(sortField, "-") {
direction = "-"
colName = strings.TrimPrefix(sortField, "-")
} else if strings.HasPrefix(sortField, "+") {
direction = "+"
colName = strings.TrimPrefix(sortField, "+")
}
if strings.HasSuffix(strings.ToLower(colName), " desc") {
direction = " desc"
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
direction = " asc"
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
}
colName = strings.TrimSpace(colName)
// Validate the column name
if expandValidator.IsValidColumn(colName) {
validSortFields = append(validSortFields, direction+colName)
} else if strings.HasPrefix(colName, "(") && strings.HasSuffix(colName, ")") {
// Allow sort by expression/subquery, but validate for security
if common.IsSafeSortExpression(colName) {
validSortFields = append(validSortFields, direction+colName)
} else {
logger.Warn("Unsafe sort expression in expand '%s' removed: '%s'", expand.Relation, colName)
}
} else {
logger.Warn("Invalid column in expand '%s' sort '%s' removed", expand.Relation, colName)
}
}
filteredExpand.Sort = strings.Join(validSortFields, ",")
}
} else {
// If we can't find the relationship, log a warning and skip column filtering
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)

View File

@@ -529,19 +529,47 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
}
// parseCommaSeparated parses comma-separated values and trims whitespace
// It respects bracket nesting and only splits on commas outside of parentheses
func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
result := make([]string, 0)
var current strings.Builder
nestingLevel := 0
for _, char := range value {
switch char {
case '(':
nestingLevel++
current.WriteRune(char)
case ')':
nestingLevel--
current.WriteRune(char)
case ',':
if nestingLevel == 0 {
// We're outside all brackets, so split here
part := strings.TrimSpace(current.String())
if part != "" {
result = append(result, part)
}
current.Reset()
} else {
// Inside brackets, keep the comma
current.WriteRune(char)
}
default:
current.WriteRune(char)
}
}
// Add the last part
part := strings.TrimSpace(current.String())
if part != "" {
result = append(result, part)
}
return result
}