diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index e7cefd4..d4a0706 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -138,7 +138,10 @@ func TestSanitizeWhereClause(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := SanitizeWhereClause(tt.where, tt.tableName) + // First add table prefixes to unqualified columns + prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName) + // Then sanitize the where clause + result := SanitizeWhereClause(prefixedWhere, tt.tableName) if result != tt.expected { t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected) } @@ -348,6 +351,7 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) { tableName string options *RequestOptions expected string + addPrefix bool }{ { name: "preload relation prefix is preserved", @@ -416,15 +420,30 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) { options: &RequestOptions{Preload: []PreloadOption{}}, expected: "users.status = 'active'", }, + + { + name: "complex where clause with subquery and preload", + where: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (rid_parentmastertaskitem is null)`, + tableName: "mastertaskitem", + options: nil, + expected: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (mastertaskitem.rid_parentmastertaskitem is null)`, + addPrefix: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var result string + prefixedWhere := tt.where + if tt.addPrefix { + // First add table prefixes to unqualified columns + prefixedWhere = AddTablePrefixToColumns(tt.where, tt.tableName) + } + // Then sanitize the where clause if tt.options != nil { - result = SanitizeWhereClause(tt.where, tt.tableName, tt.options) + result = SanitizeWhereClause(prefixedWhere, tt.tableName, tt.options) } else { - result = SanitizeWhereClause(tt.where, tt.tableName) + result = SanitizeWhereClause(prefixedWhere, tt.tableName) } if result != tt.expected { t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected) diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index f109fd5..eda5f4c 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -861,7 +861,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co if len(preload.Where) > 0 { // Build RequestOptions with all preloads to allow references to sibling relations preloadOpts := &common.RequestOptions{Preload: allPreloads} - sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) + // First add table prefixes to unqualified columns + prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) + // Then sanitize and allow preload table prefixes + sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index a64cdb4..7c5d209 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -935,7 +935,16 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption // Add WHERE clause if SQL conditions specified whereConditions := make([]string, 0) if len(xfile.SqlAnd) > 0 { - whereConditions = append(whereConditions, xfile.SqlAnd...) + // Process each SQL condition: add table prefixes and sanitize + for _, sqlCond := range xfile.SqlAnd { + // First add table prefixes to unqualified columns + prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName) + // Then sanitize the condition + sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName) + if sanitizedCond != "" { + whereConditions = append(whereConditions, sanitizedCond) + } + } } if len(whereConditions) > 0 { preloadOpt.Where = strings.Join(whereConditions, " AND ")