Compare commits

..

2 Commits

Author SHA1 Message Date
Hein
c12c045db1 feat(validation): Clear JoinAliases in FilterRequestOptions
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -27m20s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m49s
Build , Vet Test, and Lint / Build (push) Successful in -26m53s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m22s
Tests / Integration Tests (push) Failing after -27m37s
Tests / Unit Tests (push) Successful in -27m25s
* Implemented logic to clear JoinAliases after filtering.
* Added unit test to verify JoinAliases is nil post-filtering.
* Ensured other fields are correctly filtered.
2026-01-15 14:43:11 +02:00
Hein
24a7ef7284 feat(restheadspec): Add support for join aliases in filters and sorts
- Extract join aliases from custom SQL JOIN clauses.
- Validate join aliases for filtering and sorting operations.
- Update documentation to reflect new functionality.
- Enhance tests for alias extraction and usage.
2026-01-15 14:18:25 +02:00
7 changed files with 290 additions and 8 deletions

View File

@@ -166,6 +166,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
} }
} }
// Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases {
if alias != "" {
allowedPrefixes[alias] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
}
}
} }
// Split by AND to handle multiple conditions // Split by AND to handle multiple conditions

View File

@@ -23,6 +23,10 @@ type RequestOptions struct {
CursorForward string `json:"cursor_forward"` CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"` CursorBackward string `json:"cursor_backward"`
FetchRowNumber *string `json:"fetch_row_number"` FetchRowNumber *string `json:"fetch_row_number"`
// Join table aliases (used for validation of prefixed columns in filters/sorts)
// Not serialized to JSON as it's internal validation state
JoinAliases []string `json:"-"`
} }
type Parameter struct { type Parameter struct {

View File

@@ -237,15 +237,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
for _, sort := range options.Sort { for _, sort := range options.Sort {
if v.IsValidColumn(sort.Column) { if v.IsValidColumn(sort.Column) {
validSorts = append(validSorts, sort) validSorts = append(validSorts, sort)
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security
if IsSafeSortExpression(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
}
} else { } else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column) foundJoin := false
for _, j := range options.JoinAliases {
if strings.Contains(sort.Column, j) {
foundJoin = true
break
}
}
if foundJoin {
validSorts = append(validSorts, sort)
continue
}
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security
if IsSafeSortExpression(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
}
} else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
}
} }
} }
filtered.Sort = validSorts filtered.Sort = validSorts
@@ -291,6 +305,9 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
} }
filtered.Preload = validPreloads filtered.Preload = validPreloads
// Clear JoinAliases - this is an internal validation field and should not be persisted
filtered.JoinAliases = nil
return filtered return filtered
} }

View File

@@ -362,6 +362,29 @@ func TestFilterRequestOptions(t *testing.T) {
} }
} }
func TestFilterRequestOptions_ClearsJoinAliases(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
options := RequestOptions{
Columns: []string{"id", "name"},
// Set JoinAliases - this should be cleared by FilterRequestOptions
JoinAliases: []string{"d", "u", "r"},
}
filtered := validator.FilterRequestOptions(options)
// Verify that JoinAliases was cleared (internal field should not persist)
if filtered.JoinAliases != nil {
t.Errorf("Expected JoinAliases to be nil after filtering, got %v", filtered.JoinAliases)
}
// Verify that other fields are still properly filtered
if len(filtered.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
}
}
func TestIsSafeSortExpression(t *testing.T) { func TestIsSafeSortExpression(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -233,6 +233,27 @@ x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN role
- Multiple JOINs can be specified using the pipe `|` separator - Multiple JOINs can be specified using the pipe `|` separator
- JOINs are sanitized for security - JOINs are sanitized for security
- Can be specified via headers or query parameters - Can be specified via headers or query parameters
- **Table aliases are automatically extracted and allowed for filtering and sorting**
**Using Join Aliases in Filters and Sorts:**
When you specify a custom SQL join with an alias, you can use that alias in your filter and sort parameters:
```
# Join with alias
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
# Sort by joined table column
x-sort: d.name,employees.id
# Filter by joined table column
x-searchop-eq-d.name: Engineering
```
The system automatically:
1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`)
2. Validates that prefixed columns (like `d.name`) refer to valid join aliases
3. Allows these prefixed columns in filters and sorts
--- ---

View File

@@ -28,6 +28,7 @@ type ExtendedRequestOptions struct {
// Joins // Joins
Expand []ExpandOption Expand []ExpandOption
CustomSQLJoin []string // Custom SQL JOIN clauses CustomSQLJoin []string // Custom SQL JOIN clauses
JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation
// Advanced features // Advanced features
AdvancedSQL map[string]string // Column -> SQL expression AdvancedSQL map[string]string // Column -> SQL expression
@@ -528,11 +529,69 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
continue continue
} }
// Extract table alias from the JOIN clause
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
options.JoinAliases = append(options.JoinAliases, alias)
// Also add to the embedded RequestOptions for validation
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
logger.Debug("Extracted join alias: %s", alias)
}
logger.Debug("Adding custom SQL join: %s", sanitizedJoin) logger.Debug("Adding custom SQL join: %s", sanitizedJoin)
options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin) options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin)
} }
} }
// extractJoinAlias extracts the table alias from a JOIN clause
// Examples:
// - "LEFT JOIN departments d ON ..." -> "d"
// - "INNER JOIN users AS u ON ..." -> "u"
// - "JOIN roles r ON ..." -> "r"
func extractJoinAlias(joinClause string) string {
// Pattern: JOIN table_name [AS] alias ON ...
// We need to extract the alias (word before ON)
upperJoin := strings.ToUpper(joinClause)
// Find the "JOIN" keyword position
joinIdx := strings.Index(upperJoin, "JOIN")
if joinIdx == -1 {
return ""
}
// Find the "ON" keyword position
onIdx := strings.Index(upperJoin, " ON ")
if onIdx == -1 {
return ""
}
// Extract the part between JOIN and ON
betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx])
// Split by spaces to get words
words := strings.Fields(betweenJoinAndOn)
if len(words) == 0 {
return ""
}
// If there's an AS keyword, the alias is after it
for i, word := range words {
if strings.EqualFold(word, "AS") && i+1 < len(words) {
return words[i+1]
}
}
// Otherwise, the alias is the last word (if there are 2+ words)
// Format: "table_name alias" or just "table_name"
if len(words) >= 2 {
return words[len(words)-1]
}
// Only one word means it's just the table name, no alias
return ""
}
// parseSorting parses x-sort header // parseSorting parses x-sort header
// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC) // Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC)
func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {

View File

@@ -357,6 +357,107 @@ func TestParseOptionsFromQueryParams(t *testing.T) {
} }
}, },
}, },
{
name: "Extract aliases from custom SQL JOIN",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.JoinAliases) == 0 {
t.Error("Expected JoinAliases to be extracted")
return
}
if len(options.JoinAliases) != 1 {
t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases))
return
}
if options.JoinAliases[0] != "d" {
t.Errorf("Expected join alias 'd', got %q", options.JoinAliases[0])
}
// Also check that it's in the embedded RequestOptions
if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" {
t.Error("Expected join alias to also be in RequestOptions.JoinAliases")
}
},
},
{
name: "Extract multiple aliases from multiple custom SQL JOINs",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.JoinAliases) != 2 {
t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases))
return
}
expectedAliases := []string{"d", "r"}
for i, expected := range expectedAliases {
if options.JoinAliases[i] != expected {
t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i])
}
}
},
},
{
name: "Custom JOIN with sort on joined table",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
"x-sort": "d.name,employees.id",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
// Verify join was added
if len(options.CustomSQLJoin) != 1 {
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
return
}
// Verify alias was extracted
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
t.Error("Expected join alias 'd' to be extracted")
return
}
// Verify sort was parsed
if len(options.Sort) != 2 {
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
return
}
if options.Sort[0].Column != "d.name" {
t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column)
}
if options.Sort[1].Column != "employees.id" {
t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column)
}
},
},
{
name: "Custom JOIN with filter on joined table",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
"x-searchop-eq-d.name": "Engineering",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
// Verify join was added
if len(options.CustomSQLJoin) != 1 {
t.Error("Expected 1 custom SQL join")
return
}
// Verify alias was extracted
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
t.Error("Expected join alias 'd' to be extracted")
return
}
// Verify filter was parsed
if len(options.Filters) != 1 {
t.Errorf("Expected 1 filter, got %d", len(options.Filters))
return
}
if options.Filters[0].Column != "d.name" {
t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column)
}
if options.Filters[0].Operator != "eq" {
t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator)
}
},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -451,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) {
} }
} }
// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function
func TestCustomJoinAliasExtraction(t *testing.T) {
tests := []struct {
name string
join string
expected string
}{
{
name: "LEFT JOIN with alias",
join: "LEFT JOIN departments d ON d.id = employees.department_id",
expected: "d",
},
{
name: "INNER JOIN with AS keyword",
join: "INNER JOIN users AS u ON u.id = posts.user_id",
expected: "u",
},
{
name: "Simple JOIN with alias",
join: "JOIN roles r ON r.id = user_roles.role_id",
expected: "r",
},
{
name: "JOIN without alias (just table name)",
join: "JOIN departments ON departments.id = employees.dept_id",
expected: "",
},
{
name: "RIGHT JOIN with alias",
join: "RIGHT JOIN orders o ON o.customer_id = customers.id",
expected: "o",
},
{
name: "FULL OUTER JOIN with AS",
join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id",
expected: "p",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractJoinAlias(tt.join)
if result != tt.expected {
t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected)
}
})
}
}
// Helper function to check if a string contains a substring // Helper function to check if a string contains a substring
func contains(s, substr string) bool { func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr)) return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))