From 24a7ef7284eab3eea7d902f33eb8f24d3ca2bed1 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 15 Jan 2026 14:18:25 +0200 Subject: [PATCH] =?UTF-8?q?feat(restheadspec):=20=E2=9C=A8=20Add=20support?= =?UTF-8?q?=20for=20join=20aliases=20in=20filters=20and=20sorts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract join aliases from custom SQL JOIN clauses. - Validate join aliases for filtering and sorting operations. - Update documentation to reflect new functionality. - Enhance tests for alias extraction and usage. --- pkg/common/sql_helpers.go | 8 ++ pkg/common/types.go | 4 + pkg/restheadspec/HEADERS.md | 21 ++++ pkg/restheadspec/headers.go | 59 ++++++++++ pkg/restheadspec/query_params_test.go | 150 ++++++++++++++++++++++++++ 5 files changed, 242 insertions(+) diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 0af6616..26d8053 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -166,6 +166,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) } } + + // Add join aliases as allowed prefixes + for _, alias := range options[0].JoinAliases { + if alias != "" { + allowedPrefixes[alias] = true + logger.Debug("Added join alias '%s' as allowed table prefix", alias) + } + } } // Split by AND to handle multiple conditions diff --git a/pkg/common/types.go b/pkg/common/types.go index b09b3db..3e81ab9 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -23,6 +23,10 @@ type RequestOptions struct { CursorForward string `json:"cursor_forward"` CursorBackward string `json:"cursor_backward"` FetchRowNumber *string `json:"fetch_row_number"` + + // Join table aliases (used for validation of prefixed columns in filters/sorts) + // Not serialized to JSON as it's internal validation state + JoinAliases []string `json:"-"` } type Parameter struct { diff --git a/pkg/restheadspec/HEADERS.md b/pkg/restheadspec/HEADERS.md index 0149a0b..147f6ce 100644 --- a/pkg/restheadspec/HEADERS.md +++ b/pkg/restheadspec/HEADERS.md @@ -233,6 +233,27 @@ x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN role - Multiple JOINs can be specified using the pipe `|` separator - JOINs are sanitized for security - Can be specified via headers or query parameters +- **Table aliases are automatically extracted and allowed for filtering and sorting** + +**Using Join Aliases in Filters and Sorts:** + +When you specify a custom SQL join with an alias, you can use that alias in your filter and sort parameters: + +``` +# Join with alias +x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id + +# Sort by joined table column +x-sort: d.name,employees.id + +# Filter by joined table column +x-searchop-eq-d.name: Engineering +``` + +The system automatically: +1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`) +2. Validates that prefixed columns (like `d.name`) refer to valid join aliases +3. Allows these prefixed columns in filters and sorts --- diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index eb32fb8..ef51cbc 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -28,6 +28,7 @@ type ExtendedRequestOptions struct { // Joins Expand []ExpandOption CustomSQLJoin []string // Custom SQL JOIN clauses + JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation // Advanced features AdvancedSQL map[string]string // Column -> SQL expression @@ -528,11 +529,69 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri continue } + // Extract table alias from the JOIN clause + alias := extractJoinAlias(sanitizedJoin) + if alias != "" { + options.JoinAliases = append(options.JoinAliases, alias) + // Also add to the embedded RequestOptions for validation + options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias) + logger.Debug("Extracted join alias: %s", alias) + } + logger.Debug("Adding custom SQL join: %s", sanitizedJoin) options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin) } } +// extractJoinAlias extracts the table alias from a JOIN clause +// Examples: +// - "LEFT JOIN departments d ON ..." -> "d" +// - "INNER JOIN users AS u ON ..." -> "u" +// - "JOIN roles r ON ..." -> "r" +func extractJoinAlias(joinClause string) string { + // Pattern: JOIN table_name [AS] alias ON ... + // We need to extract the alias (word before ON) + + upperJoin := strings.ToUpper(joinClause) + + // Find the "JOIN" keyword position + joinIdx := strings.Index(upperJoin, "JOIN") + if joinIdx == -1 { + return "" + } + + // Find the "ON" keyword position + onIdx := strings.Index(upperJoin, " ON ") + if onIdx == -1 { + return "" + } + + // Extract the part between JOIN and ON + betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx]) + + // Split by spaces to get words + words := strings.Fields(betweenJoinAndOn) + if len(words) == 0 { + return "" + } + + // If there's an AS keyword, the alias is after it + for i, word := range words { + if strings.EqualFold(word, "AS") && i+1 < len(words) { + return words[i+1] + } + } + + // Otherwise, the alias is the last word (if there are 2+ words) + // Format: "table_name alias" or just "table_name" + if len(words) >= 2 { + return words[len(words)-1] + } + + // Only one word means it's just the table name, no alias + return "" +} + // parseSorting parses x-sort header // Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC) func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { diff --git a/pkg/restheadspec/query_params_test.go b/pkg/restheadspec/query_params_test.go index 5ea19ec..9b37768 100644 --- a/pkg/restheadspec/query_params_test.go +++ b/pkg/restheadspec/query_params_test.go @@ -357,6 +357,107 @@ func TestParseOptionsFromQueryParams(t *testing.T) { } }, }, + { + name: "Extract aliases from custom SQL JOIN", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.JoinAliases) == 0 { + t.Error("Expected JoinAliases to be extracted") + return + } + if len(options.JoinAliases) != 1 { + t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases)) + return + } + if options.JoinAliases[0] != "d" { + t.Errorf("Expected join alias 'd', got %q", options.JoinAliases[0]) + } + // Also check that it's in the embedded RequestOptions + if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" { + t.Error("Expected join alias to also be in RequestOptions.JoinAliases") + } + }, + }, + { + name: "Extract multiple aliases from multiple custom SQL JOINs", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.JoinAliases) != 2 { + t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases)) + return + } + expectedAliases := []string{"d", "r"} + for i, expected := range expectedAliases { + if options.JoinAliases[i] != expected { + t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i]) + } + } + }, + }, + { + name: "Custom JOIN with sort on joined table", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + "x-sort": "d.name,employees.id", + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + // Verify join was added + if len(options.CustomSQLJoin) != 1 { + t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin)) + return + } + // Verify alias was extracted + if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" { + t.Error("Expected join alias 'd' to be extracted") + return + } + // Verify sort was parsed + if len(options.Sort) != 2 { + t.Errorf("Expected 2 sort options, got %d", len(options.Sort)) + return + } + if options.Sort[0].Column != "d.name" { + t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column) + } + if options.Sort[1].Column != "employees.id" { + t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column) + } + }, + }, + { + name: "Custom JOIN with filter on joined table", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + "x-searchop-eq-d.name": "Engineering", + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + // Verify join was added + if len(options.CustomSQLJoin) != 1 { + t.Error("Expected 1 custom SQL join") + return + } + // Verify alias was extracted + if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" { + t.Error("Expected join alias 'd' to be extracted") + return + } + // Verify filter was parsed + if len(options.Filters) != 1 { + t.Errorf("Expected 1 filter, got %d", len(options.Filters)) + return + } + if options.Filters[0].Column != "d.name" { + t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column) + } + if options.Filters[0].Operator != "eq" { + t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator) + } + }, + }, } for _, tt := range tests { @@ -451,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) { } } +// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function +func TestCustomJoinAliasExtraction(t *testing.T) { + tests := []struct { + name string + join string + expected string + }{ + { + name: "LEFT JOIN with alias", + join: "LEFT JOIN departments d ON d.id = employees.department_id", + expected: "d", + }, + { + name: "INNER JOIN with AS keyword", + join: "INNER JOIN users AS u ON u.id = posts.user_id", + expected: "u", + }, + { + name: "Simple JOIN with alias", + join: "JOIN roles r ON r.id = user_roles.role_id", + expected: "r", + }, + { + name: "JOIN without alias (just table name)", + join: "JOIN departments ON departments.id = employees.dept_id", + expected: "", + }, + { + name: "RIGHT JOIN with alias", + join: "RIGHT JOIN orders o ON o.customer_id = customers.id", + expected: "o", + }, + { + name: "FULL OUTER JOIN with AS", + join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id", + expected: "p", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractJoinAlias(tt.join) + if result != tt.expected { + t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected) + } + }) + } +} + // Helper function to check if a string contains a substring func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))