diff --git a/pkg/common/types.go b/pkg/common/types.go index 3e81ab9..a68daf6 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -52,6 +52,10 @@ type PreloadOption struct { PrimaryKey string `json:"primary_key"` // Primary key of the related table RelatedKey string `json:"related_key"` // For child tables: column in child that references parent ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent + + // Custom SQL JOINs from XFiles - used when preload needs additional joins + SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses + JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation } type FilterOption struct { diff --git a/pkg/common/validation.go b/pkg/common/validation.go index 653a869..1a7ae7d 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -272,13 +272,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp filteredPreload.Columns = v.FilterValidColumns(preload.Columns) filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns) + // Preserve SqlJoins and JoinAliases for preloads with custom joins + filteredPreload.SqlJoins = preload.SqlJoins + filteredPreload.JoinAliases = preload.JoinAliases + // Filter preload filters validPreloadFilters := make([]FilterOption, 0, len(preload.Filters)) for _, filter := range preload.Filters { if v.IsValidColumn(filter.Column) { validPreloadFilters = append(validPreloadFilters, filter) } else { - logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column) + // Check if the filter column references a joined table alias + foundJoin := false + for _, alias := range preload.JoinAliases { + if strings.Contains(filter.Column, alias) { + foundJoin = true + break + } + } + if foundJoin { + validPreloadFilters = append(validPreloadFilters, filter) + } else { + logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column) + } } } filteredPreload.Filters = validPreloadFilters diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index d5f0aa1..4bc6c08 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -882,6 +882,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co } } + // Apply custom SQL joins from XFiles + if len(preload.SqlJoins) > 0 { + logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation) + for _, joinClause := range preload.SqlJoins { + sq = sq.Join(joinClause) + logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause) + } + } + // Apply filters if len(preload.Filters) > 0 { for _, filter := range preload.Filters { diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index ef51cbc..cce3a1e 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -1088,6 +1088,32 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey) } + // Transfer SqlJoins from XFiles to PreloadOption + if len(xfile.SqlJoins) > 0 { + preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins)) + preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins)) + + for _, joinClause := range xfile.SqlJoins { + // Sanitize the join clause + sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil) + if sanitizedJoin == "" { + logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause) + continue + } + + preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin) + + // Extract join alias for validation + alias := extractJoinAlias(sanitizedJoin) + if alias != "" { + preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias) + logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias) + } + } + + logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath) + } + // Add the preload option options.Preload = append(options.Preload, preloadOpt) diff --git a/pkg/restheadspec/headers_test.go b/pkg/restheadspec/headers_test.go index 8117483..d83d09f 100644 --- a/pkg/restheadspec/headers_test.go +++ b/pkg/restheadspec/headers_test.go @@ -2,6 +2,8 @@ package restheadspec import ( "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" ) func TestDecodeHeaderValue(t *testing.T) { @@ -37,6 +39,121 @@ func TestDecodeHeaderValue(t *testing.T) { } } +func TestAddXFilesPreload_WithSqlJoins(t *testing.T) { + handler := &Handler{} + options := &ExtendedRequestOptions{ + RequestOptions: common.RequestOptions{ + Preload: make([]common.PreloadOption, 0), + }, + } + + // Create an XFiles with SqlJoins + xfile := &XFiles{ + TableName: "users", + SqlJoins: []string{ + "LEFT JOIN departments d ON d.id = users.department_id", + "INNER JOIN roles r ON r.id = users.role_id", + }, + FilterFields: []struct { + Field string `json:"field"` + Value string `json:"value"` + Operator string `json:"operator"` + }{ + {Field: "d.active", Value: "true", Operator: "eq"}, + {Field: "r.name", Value: "admin", Operator: "eq"}, + }, + } + + // Add the XFiles preload + handler.addXFilesPreload(xfile, options, "") + + // Verify that a preload was added + if len(options.Preload) != 1 { + t.Fatalf("Expected 1 preload, got %d", len(options.Preload)) + } + + preload := options.Preload[0] + + // Verify relation name + if preload.Relation != "users" { + t.Errorf("Expected relation 'users', got '%s'", preload.Relation) + } + + // Verify SqlJoins were transferred + if len(preload.SqlJoins) != 2 { + t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins)) + } + + // Verify JoinAliases were extracted + if len(preload.JoinAliases) != 2 { + t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases)) + } + + // Verify the aliases are correct + expectedAliases := []string{"d", "r"} + for i, expected := range expectedAliases { + if preload.JoinAliases[i] != expected { + t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i]) + } + } + + // Verify filters were added + if len(preload.Filters) != 2 { + t.Fatalf("Expected 2 filters, got %d", len(preload.Filters)) + } + + // Verify filter columns reference joined tables + if preload.Filters[0].Column != "d.active" { + t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column) + } + if preload.Filters[1].Column != "r.name" { + t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column) + } +} + +func TestExtractJoinAlias(t *testing.T) { + tests := []struct { + name string + joinClause string + expected string + }{ + { + name: "LEFT JOIN with alias", + joinClause: "LEFT JOIN departments d ON d.id = users.department_id", + expected: "d", + }, + { + name: "INNER JOIN with AS keyword", + joinClause: "INNER JOIN users AS u ON u.id = orders.user_id", + expected: "u", + }, + { + name: "JOIN without alias", + joinClause: "JOIN roles ON roles.id = users.role_id", + expected: "", + }, + { + name: "Complex join with multiple conditions", + joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true", + expected: "p", + }, + { + name: "Invalid join (no ON clause)", + joinClause: "LEFT JOIN departments", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractJoinAlias(tt.joinClause) + if result != tt.expected { + t.Errorf("Expected alias '%s', got '%s'", tt.expected, result) + } + }) + } +} + // Note: The following functions are unexported (lowercase) and cannot be tested directly: // - parseSelectFields // - parseFieldFilter