diff --git a/pkg/restheadspec/HEADERS.md b/pkg/restheadspec/HEADERS.md index c422404..0149a0b 100644 --- a/pkg/restheadspec/HEADERS.md +++ b/pkg/restheadspec/HEADERS.md @@ -214,14 +214,25 @@ x-expand: department:id,name,code **Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation. #### `x-custom-sql-join` -Raw SQL JOIN statement. +Custom SQL JOIN clauses for joining tables in queries. -**Format:** SQL JOIN clause +**Format:** SQL JOIN clause or multiple clauses separated by `|` + +**Single JOIN:** ``` x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id ``` -⚠️ **Note:** Not yet fully implemented. +**Multiple JOINs:** +``` +x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id +``` + +**Features:** +- Supports any type of JOIN (INNER, LEFT, RIGHT, FULL, CROSS) +- Multiple JOINs can be specified using the pipe `|` separator +- JOINs are sanitized for security +- Can be specified via headers or query parameters --- diff --git a/pkg/restheadspec/cache_helpers.go b/pkg/restheadspec/cache_helpers.go index 094e435..1e81187 100644 --- a/pkg/restheadspec/cache_helpers.go +++ b/pkg/restheadspec/cache_helpers.go @@ -26,6 +26,7 @@ type queryCacheKey struct { Sort []common.SortOption `json:"sort"` CustomSQLWhere string `json:"custom_sql_where,omitempty"` CustomSQLOr string `json:"custom_sql_or,omitempty"` + CustomSQLJoin []string `json:"custom_sql_join,omitempty"` Expand []expandOptionKey `json:"expand,omitempty"` Distinct bool `json:"distinct,omitempty"` CursorForward string `json:"cursor_forward,omitempty"` @@ -40,7 +41,7 @@ type cachedTotal struct { // buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec) // Includes expand, distinct, and cursor pagination options func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, - customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string { + customWhere, customOr string, customJoin []string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string { key := queryCacheKey{ TableName: tableName, @@ -48,6 +49,7 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, Sort: sort, CustomSQLWhere: customWhere, CustomSQLOr: customOr, + CustomSQLJoin: customJoin, Distinct: distinct, CursorForward: cursorFwd, CursorBackward: cursorBwd, @@ -75,8 +77,8 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, jsonData, err := json.Marshal(key) if err != nil { // Fallback to simple string concatenation if JSON fails - return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s", - tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd)) + return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%v_%s_%s", + tableName, filters, sort, customWhere, customOr, customJoin, expandOpts, distinct, cursorFwd, cursorBwd)) } return hashString(string(jsonData)) diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index f3dc8d8..c176bcb 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -502,6 +502,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } } + // Apply custom SQL JOIN clauses + if len(options.CustomSQLJoin) > 0 { + for _, joinClause := range options.CustomSQLJoin { + logger.Debug("Applying custom SQL JOIN: %s", joinClause) + // Joins are already sanitized during parsing, so we can apply them directly + query = query.Join(joinClause) + } + } + // If ID is provided, filter by ID if id != "" { pkName := reflection.GetPrimaryKeyName(model) @@ -552,6 +561,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st options.Sort, options.CustomSQLWhere, options.CustomSQLOr, + options.CustomSQLJoin, expandOpts, options.Distinct, options.CursorForward, diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index bdae2bd..eb32fb8 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -26,7 +26,8 @@ type ExtendedRequestOptions struct { CustomSQLOr string // Joins - Expand []ExpandOption + Expand []ExpandOption + CustomSQLJoin []string // Custom SQL JOIN clauses // Advanced features AdvancedSQL map[string]string // Column -> SQL expression @@ -111,6 +112,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E AdvancedSQL: make(map[string]string), ComputedQL: make(map[string]string), Expand: make([]ExpandOption, 0), + CustomSQLJoin: make([]string, 0), ResponseFormat: "simple", // Default response format SingleRecordAsObject: true, // Default: normalize single-element arrays to objects } @@ -185,8 +187,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E case strings.HasPrefix(key, "x-expand"): h.parseExpand(&options, decodedValue) case strings.HasPrefix(key, "x-custom-sql-join"): - // TODO: Implement custom SQL join - logger.Debug("Custom SQL join not yet implemented: %s", decodedValue) + h.parseCustomSQLJoin(&options, decodedValue) // Sorting & Pagination case strings.HasPrefix(key, "x-sort"): @@ -495,6 +496,43 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) { } } +// parseCustomSQLJoin parses x-custom-sql-join header +// Format: Single JOIN clause or multiple JOIN clauses separated by | +// Example: "LEFT JOIN departments d ON d.id = employees.department_id" +// Example: "LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id" +func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + + // Split by | for multiple joins + joins := strings.Split(value, "|") + for _, joinStr := range joins { + joinStr = strings.TrimSpace(joinStr) + if joinStr == "" { + continue + } + + // Basic validation: should contain "JOIN" keyword + upperJoin := strings.ToUpper(joinStr) + if !strings.Contains(upperJoin, "JOIN") { + logger.Warn("Invalid custom SQL join (missing JOIN keyword): %s", joinStr) + continue + } + + // Sanitize the join clause using common.SanitizeWhereClause + // Note: This is basic sanitization - in production you may want stricter validation + sanitizedJoin := common.SanitizeWhereClause(joinStr, "", nil) + if sanitizedJoin == "" { + logger.Warn("Custom SQL join failed sanitization: %s", joinStr) + continue + } + + logger.Debug("Adding custom SQL join: %s", sanitizedJoin) + options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin) + } +} + // parseSorting parses x-sort header // Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC) func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { diff --git a/pkg/restheadspec/query_params_test.go b/pkg/restheadspec/query_params_test.go index ac1beeb..5ea19ec 100644 --- a/pkg/restheadspec/query_params_test.go +++ b/pkg/restheadspec/query_params_test.go @@ -301,6 +301,62 @@ func TestParseOptionsFromQueryParams(t *testing.T) { } }, }, + { + name: "Parse custom SQL JOIN from query params", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.CustomSQLJoin) == 0 { + t.Error("Expected CustomSQLJoin to be set") + return + } + if len(options.CustomSQLJoin) != 1 { + t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin)) + return + } + expected := `LEFT JOIN departments d ON d.id = employees.department_id` + if options.CustomSQLJoin[0] != expected { + t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0]) + } + }, + }, + { + name: "Parse multiple custom SQL JOINs from query params", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.CustomSQLJoin) != 2 { + t.Errorf("Expected 2 custom SQL joins, got %d", len(options.CustomSQLJoin)) + return + } + expected1 := `LEFT JOIN departments d ON d.id = e.dept_id` + expected2 := `INNER JOIN roles r ON r.id = e.role_id` + if options.CustomSQLJoin[0] != expected1 { + t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected1, options.CustomSQLJoin[0]) + } + if options.CustomSQLJoin[1] != expected2 { + t.Errorf("Expected CustomSQLJoin[1]=%q, got %q", expected2, options.CustomSQLJoin[1]) + } + }, + }, + { + name: "Parse custom SQL JOIN from headers", + headers: map[string]string{ + "X-Custom-SQL-Join": `LEFT JOIN users u ON u.id = posts.user_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.CustomSQLJoin) == 0 { + t.Error("Expected CustomSQLJoin to be set from header") + return + } + expected := `LEFT JOIN users u ON u.id = posts.user_id` + if options.CustomSQLJoin[0] != expected { + t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0]) + } + }, + }, } for _, tt := range tests { diff --git a/pkg/restheadspec/restheadspec.go b/pkg/restheadspec/restheadspec.go index cfe0378..5a743bb 100644 --- a/pkg/restheadspec/restheadspec.go +++ b/pkg/restheadspec/restheadspec.go @@ -32,6 +32,7 @@ // - X-Clean-JSON: Boolean to remove null/empty fields // - X-Custom-SQL-Where: Custom SQL WHERE clause (AND) // - X-Custom-SQL-Or: Custom SQL WHERE clause (OR) +// - X-Custom-SQL-Join: Custom SQL JOIN clauses (pipe-separated for multiple) // // # Usage Example //