Compare commits

...

1 Commits

Author SHA1 Message Date
Hein
e70bab92d7 feat(tests): 🎉 More test for preload fixes.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m14s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m10s
Build , Vet Test, and Lint / Build (push) Successful in -26m22s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m12s
Tests / Integration Tests (push) Failing after -26m58s
Tests / Unit Tests (push) Successful in -26m47s
* Implement tests for SanitizeWhereClause and AddTablePrefixToColumns.
* Ensure correct handling of table prefixes in WHERE clauses.
* Validate that unqualified columns are prefixed correctly when necessary.
* Add tests for XFiles processing to verify table name handling.
* Introduce tests for recursive preloads and their related keys.
2026-01-30 10:09:59 +02:00
8 changed files with 483 additions and 324 deletions

View File

@@ -208,19 +208,11 @@ type BunSelectQuery struct {
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries
inJoinContext bool // Track if we're in a JOIN relation context inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions joinTableAlias string // Alias to use for JOIN conditions
skipAutoDetect bool // Skip auto-detection to prevent circular calls skipAutoDetect bool // Skip auto-detection to prevent circular calls
} }
// deferredPreload represents a preload that will be executed as a separate query
// to avoid PostgreSQL identifier length limits
type deferredPreload struct {
relation string
apply []func(common.SelectQuery) common.SelectQuery
}
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
b.query = b.query.Model(model) b.query = b.query.Model(model)
b.hasModel = true // Mark that we have a model b.hasModel = true // Mark that we have a model
@@ -487,51 +479,8 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
return b return b
} }
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
// // when combined with typical column names
// func shortenAliasForPostgres(relationPath string) (string, bool) {
// // Convert relation path to the alias format Bun uses: dots become double underscores
// // Also convert to lowercase and use snake_case as Bun does
// parts := strings.Split(relationPath, ".")
// alias := strings.ToLower(strings.Join(parts, "__"))
// // PostgreSQL truncates identifiers to 63 chars
// // If the alias + typical column name would exceed this, we need to shorten
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
// const maxAliasLength = 30
// if len(alias) > maxAliasLength {
// // Create a shortened alias using a hash of the original
// hash := md5.Sum([]byte(alias))
// hashStr := hex.EncodeToString(hash[:])[:8]
// // Keep first few chars of original for readability + hash
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
// if prefixLen > len(alias) {
// prefixLen = len(alias)
// }
// shortened := alias[:prefixLen] + "_" + hashStr
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
// alias, len(alias), shortened, len(shortened))
// return shortened, true
// }
// return alias, false
// }
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
// // Bun creates aliases like: relationChain__columnName
// func estimateColumnAliasLength(relationPath string, columnName string) int {
// relationParts := strings.Split(relationPath, ".")
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// // Bun adds "__" between alias and column name
// return len(aliasChain) + 2 + len(columnName)
// }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy // Auto-detect relationship type and choose optimal loading strategy
// Get the model from the query if available
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation) // Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
if !b.skipAutoDetect { if !b.skipAutoDetect {
model := b.query.GetModel() model := b.query.GetModel()
@@ -554,49 +503,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
} }
} }
// Check if this relation chain would create problematic long aliases // Use Bun's native Relation() for preloading
relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// PostgreSQL's identifier limit is 63 characters
const postgresIdentifierLimit = 63
const safeAliasLimit = 35 // Leave room for column names
// If the alias chain is too long, defer this preload to be executed as a separate query
if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit {
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
// This avoids the long concatenated alias
if len(relationParts) > 1 {
// Load first level normally: "Parent"
firstLevel := relationParts[0]
remainingPath := strings.Join(relationParts[1:], ".")
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
firstLevel, remainingPath)
// Apply the first level preload normally
b.query = b.query.Relation(firstLevel)
// Store the remaining nested preload to be executed after the main query
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
relation: relation,
apply: apply,
})
return b
}
// Single level but still too long - just warn and continue
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
"Consider renaming the field to avoid potential issues.",
relation, len(aliasChain))
}
// Normal preload handling
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -629,14 +536,9 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Extract table alias if model implements TableAliasProvider // Extract table alias if model implements TableAliasProvider
if provider, ok := modelValue.(common.TableAliasProvider); ok { if provider, ok := modelValue.(common.TableAliasProvider); ok {
wrapper.tableAlias = provider.TableAlias() wrapper.tableAlias = provider.TableAlias()
// Apply the alias to the Bun query so conditions can reference it
if wrapper.tableAlias != "" {
// Note: Bun's Relation() already sets up the table, but we can add
// the alias explicitly if needed
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias) logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
} }
} }
}
// Start with the interface value (not pointer) // Start with the interface value (not pointer)
current := common.SelectQuery(wrapper) current := common.SelectQuery(wrapper)
@@ -644,7 +546,6 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Apply each function in sequence // Apply each function in sequence
for _, fn := range apply { for _, fn := range apply {
if fn != nil { if fn != nil {
// Pass &current (pointer to interface variable), fn modifies and returns new interface value
modified := fn(current) modified := fn(current)
current = modified current = modified
} }
@@ -734,7 +635,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
return fmt.Errorf("destination cannot be nil") return fmt.Errorf("destination cannot be nil")
} }
// Execute the main query first
err = b.query.Scan(ctx, dest) err = b.query.Scan(ctx, dest)
if err != nil { if err != nil {
// Log SQL string for debugging // Log SQL string for debugging
@@ -743,17 +643,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
return err return err
} }
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
err = b.executeDeferredPreloads(ctx, dest)
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
// Clear deferred preloads to prevent re-execution
b.deferredPreloads = nil
}
return nil return nil
} }
@@ -803,7 +692,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
} }
} }
// Execute the main query first
err = b.query.Scan(ctx) err = b.query.Scan(ctx)
if err != nil { if err != nil {
// Log SQL string for debugging // Log SQL string for debugging
@@ -812,147 +700,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
return err return err
} }
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
model := b.query.GetModel()
err = b.executeDeferredPreloads(ctx, model.Value())
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
// Clear deferred preloads to prevent re-execution
b.deferredPreloads = nil
}
return nil return nil
} }
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
if len(b.deferredPreloads) == 0 {
return nil
}
for _, dp := range b.deferredPreloads {
err := b.executeSingleDeferredPreload(ctx, dest, dp)
if err != nil {
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
}
}
return nil
}
// executeSingleDeferredPreload executes a single deferred preload
// For a relation like "Parent.Child", it:
// 1. Finds all loaded Parent records in dest
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
relationParts := strings.Split(dp.relation, ".")
if len(relationParts) < 2 {
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
}
// The parent relation that was already loaded
parentRelation := relationParts[0]
// The child relation we need to load
childRelation := strings.Join(relationParts[1:], ".")
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
// Use reflection to access the parent relation field(s) in the loaded records
// Then load the child relation for those parent records
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Handle both slice and single record
if destValue.Kind() == reflect.Slice {
// Iterate through each record in the slice
for i := 0; i < destValue.Len(); i++ {
record := destValue.Index(i)
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
// Continue with other records
}
}
} else {
// Single record
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
}
}
return nil
}
// loadChildRelationForRecord loads a child relation for a single parent record
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
// Ensure we're working with the actual struct value, not a pointer
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Get the parent relation field
parentField := record.FieldByName(parentRelation)
if !parentField.IsValid() {
// Parent relation field doesn't exist
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
return nil
}
// Check if the parent field is nil (for pointer fields)
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
// Parent relation not loaded or nil, skip
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
return nil
}
// Get a pointer to the parent field so Bun can modify it
// CRITICAL: We need to pass a pointer, not a value, so that when Bun
// loads the child records and appends them to the slice, the changes
// are reflected in the original struct field.
var parentPtr interface{}
if parentField.Kind() == reflect.Ptr {
// Field is already a pointer (e.g., Parent *Parent), use as-is
parentPtr = parentField.Interface()
} else {
// Field is a value (e.g., Comments []Comment), get its address
if parentField.CanAddr() {
parentPtr = parentField.Addr().Interface()
} else {
return fmt.Errorf("cannot get address of field '%s'", parentRelation)
}
}
// Load the child relation on the parent record
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
// CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent
// record, not the first parent in the database table.
return b.db.NewSelect().
Model(parentPtr).
WherePK().
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
// Apply any custom query modifications
if len(apply) > 0 {
wrapper := &BunSelectQuery{query: sq, db: b.db}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalBun, ok := current.(*BunSelectQuery); ok {
return finalBun.query
}
}
return sq
}).
Scan(ctx)
}
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {

View File

@@ -0,0 +1,103 @@
package common
import (
"testing"
)
// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses
// are correctly handled when the tableName parameter matches the prefix
func TestSanitizeWhereClause_WithTableName(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "Correct table prefix should not be changed",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Wrong table prefix should be fixed",
where: "wrong_table.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Relation name should not replace correct table prefix",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: &RequestOptions{
Preload: []PreloadOption{
{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "mastertaskitem",
},
},
},
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Unqualified column should remain unqualified",
where: "rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "rid_parentmastertaskitem is null",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName, tt.options)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q",
tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// TestAddTablePrefixToColumns_WithTableName tests that table prefixes
// are correctly added to unqualified columns
func TestAddTablePrefixToColumns_WithTableName(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "Add prefix to unqualified column",
where: "rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Don't change already qualified column",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Don't change qualified column with different table",
where: "other_table.rid_something is null",
tableName: "mastertaskitem",
expected: "other_table.rid_something is null",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AddTablePrefixToColumns(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q",
tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -37,6 +37,7 @@ type Parameter struct {
type PreloadOption struct { type PreloadOption struct {
Relation string `json:"relation"` Relation string `json:"relation"`
TableName string `json:"table_name"` // Actual database table name (e.g., "mastertaskitem")
Columns []string `json:"columns"` Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"` OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"` Sort []SortOption `json:"sort"`
@@ -52,6 +53,7 @@ type PreloadOption struct {
PrimaryKey string `json:"primary_key"` // Primary key of the related table PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem")
// Custom SQL JOINs from XFiles - used when preload needs additional joins // Custom SQL JOINs from XFiles - used when preload needs additional joins
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses

View File

@@ -435,9 +435,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
// Apply preloading // Apply preloading
logger.Debug("Total preloads to apply: %d", len(options.Preload))
for idx := range options.Preload { for idx := range options.Preload {
preload := options.Preload[idx] preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation) logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s",
idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where)
// Validate and fix WHERE clause to ensure it contains the relation prefix // Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
@@ -916,10 +918,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
// Build RequestOptions with all preloads to allow references to sibling relations // Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: allPreloads} preloadOpts := &common.RequestOptions{Preload: allPreloads}
// First add table prefixes to unqualified columns
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) // Determine the table name to use for WHERE clause processing
// Then sanitize and allow preload table prefixes // Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) tableName := preload.TableName
if tableName == "" {
tableName = reflection.ExtractTableNameOnly(preload.Relation)
}
// In Bun's Relation context, table prefixes are only needed when there are JOINs
// Without JOINs, Bun already knows which table is being queried
whereClause := preload.Where
if len(preload.SqlJoins) > 0 {
// Has JOINs: add table prefixes to disambiguate columns
whereClause = common.AddTablePrefixToColumns(preload.Where, tableName)
logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause)
}
// Sanitize the WHERE clause and allow preload table prefixes
sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts)
if len(sanitizedWhere) > 0 { if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere) sq = sq.Where(sanitizedWhere)
} }
@@ -945,15 +962,35 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
lastRelationName := relationParts[len(relationParts)-1] lastRelationName := relationParts[len(relationParts)-1]
// Generate FK-based relation name for children // Generate FK-based relation name for children
// Use RecursiveChildKey if available, otherwise fall back to RelatedKey
recursiveFK := preload.RecursiveChildKey
if recursiveFK == "" {
recursiveFK = preload.RelatedKey
}
recursiveRelationName := lastRelationName recursiveRelationName := lastRelationName
if preload.RelatedKey != "" { if recursiveFK != "" {
// Convert "rid_parentmastertaskitem" to "RID_PARENTMASTERTASKITEM" // Check if the last relation name already contains the FK suffix
fkUpper := strings.ToUpper(preload.RelatedKey) // (this happens when XFiles already generated the FK-based name)
recursiveRelationName = lastRelationName + "_" + fkUpper fkUpper := strings.ToUpper(recursiveFK)
logger.Debug("Generated recursive relation name from RelatedKey: %s (from %s)", expectedSuffix := "_" + fkUpper
recursiveRelationName, preload.RelatedKey)
if strings.HasSuffix(lastRelationName, expectedSuffix) {
// Already has FK suffix, just reuse the same name
recursiveRelationName = lastRelationName
logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName)
} else { } else {
logger.Warn("Recursive preload for %s has no RelatedKey, falling back to %s.%s", // Generate FK-based name
recursiveRelationName = lastRelationName + expectedSuffix
keySource := "RelatedKey"
if preload.RecursiveChildKey != "" {
keySource = "RecursiveChildKey"
}
logger.Debug("Generated recursive relation name from %s: %s (from %s)",
keySource, recursiveRelationName, recursiveFK)
}
} else {
logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s",
preload.Relation, preload.Relation, lastRelationName) preload.Relation, preload.Relation, lastRelationName)
} }
@@ -962,6 +999,11 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
recursivePreload.Relation = preload.Relation + "." + recursiveRelationName recursivePreload.Relation = preload.Relation + "." + recursiveRelationName
recursivePreload.Recursive = false // Prevent infinite recursion at this level recursivePreload.Recursive = false // Prevent infinite recursion at this level
// Use the recursive FK for child relations, not the parent's RelatedKey
if preload.RecursiveChildKey != "" {
recursivePreload.RelatedKey = preload.RecursiveChildKey
}
// CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal // CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal
recursivePreload.Where = "" recursivePreload.Where = ""
recursivePreload.Filters = []common.FilterOption{} recursivePreload.Filters = []common.FilterOption{}

View File

@@ -49,6 +49,7 @@ type ExtendedRequestOptions struct {
// X-Files configuration - comprehensive query options as a single JSON object // X-Files configuration - comprehensive query options as a single JSON object
XFiles *XFiles XFiles *XFiles
XFilesPresent bool // Flag to indicate if X-Files header was provided
} }
// ExpandOption represents a relation expansion configuration // ExpandOption represents a relation expansion configuration
@@ -274,7 +275,8 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
} }
// Resolve relation names (convert table names to field names) if model is provided // Resolve relation names (convert table names to field names) if model is provided
if model != nil { // Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
if model != nil && !options.XFilesPresent {
h.resolveRelationNamesInOptions(&options, model) h.resolveRelationNamesInOptions(&options, model)
} }
@@ -693,6 +695,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
// Store the original XFiles for reference // Store the original XFiles for reference
options.XFiles = &xfiles options.XFiles = &xfiles
options.XFilesPresent = true // Mark that X-Files header was provided
// Map XFiles fields to ExtendedRequestOptions // Map XFiles fields to ExtendedRequestOptions
@@ -984,11 +987,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
return return
} }
// Store the table name as-is for now - it will be resolved to field name later // Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct field name
// when we have the model instance available // Fall back to TableName if Prefix is not specified
relationPath := xfile.TableName relationName := xfile.Prefix
if relationName == "" {
relationName = xfile.TableName
}
// SPECIAL CASE: For recursive child tables, generate FK-based relation name
// Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem",
// the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL"
if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" {
// Check if this is a self-referencing recursive relation (same table as parent)
// by comparing the last part of basePath with the current prefix
basePathParts := strings.Split(basePath, ".")
lastPrefix := basePathParts[len(basePathParts)-1]
if lastPrefix == relationName {
// This is a recursive self-reference, use FK-based name
fkUpper := strings.ToUpper(xfile.RelatedKey)
relationName = relationName + "_" + fkUpper
logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName)
}
}
relationPath := relationName
if basePath != "" { if basePath != "" {
relationPath = basePath + "." + xfile.TableName relationPath = basePath + "." + relationName
} }
logger.Debug("X-Files: Adding preload for relation: %s", relationPath) logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
@@ -996,6 +1021,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Create PreloadOption from XFiles configuration // Create PreloadOption from XFiles configuration
preloadOpt := common.PreloadOption{ preloadOpt := common.PreloadOption{
Relation: relationPath, Relation: relationPath,
TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing
Columns: xfile.Columns, Columns: xfile.Columns,
OmitColumns: xfile.OmitColumns, OmitColumns: xfile.OmitColumns,
} }
@@ -1038,12 +1064,12 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Add WHERE clause if SQL conditions specified // Add WHERE clause if SQL conditions specified
whereConditions := make([]string, 0) whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 { if len(xfile.SqlAnd) > 0 {
// Process each SQL condition: add table prefixes and sanitize // Process each SQL condition
// Note: We don't add table prefixes here because they're only needed for JOINs
// The handler will add prefixes later if SqlJoins are present
for _, sqlCond := range xfile.SqlAnd { for _, sqlCond := range xfile.SqlAnd {
// First add table prefixes to unqualified columns // Sanitize the condition without adding prefixes
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName) sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
// Then sanitize the condition
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
if sanitizedCond != "" { if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond) whereConditions = append(whereConditions, sanitizedCond)
} }
@@ -1114,13 +1140,46 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath) logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
} }
// Check if this table has a recursive child - if so, mark THIS preload as recursive
// and store the recursive child's RelatedKey for recursion generation
hasRecursiveChild := false
if len(xfile.ChildTables) > 0 {
for _, childTable := range xfile.ChildTables {
if childTable.Recursive && childTable.TableName == xfile.TableName {
hasRecursiveChild = true
preloadOpt.Recursive = true
preloadOpt.RecursiveChildKey = childTable.RelatedKey
logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)",
relationPath, childTable.RelatedKey)
break
}
}
}
// Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag)
if xfile.Recursive && basePath != "" {
logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath)
// Still process its parent/child tables for relations like DEF
h.processXFilesRelations(xfile, options, relationPath)
return
}
// Add the preload option // Add the preload option
options.Preload = append(options.Preload, preloadOpt) options.Preload = append(options.Preload, preloadOpt)
logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s",
len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where)
// Recursively process nested ParentTables and ChildTables // Recursively process nested ParentTables and ChildTables
if xfile.Recursive { // Skip processing child tables if we already detected and handled a recursive child
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath) if hasRecursiveChild {
h.processXFilesRelations(xfile, options, relationPath) logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath)
// But still process parent tables
if len(xfile.ParentTables) > 0 {
logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath)
for _, parentTable := range xfile.ParentTables {
h.addXFilesPreload(parentTable, options, relationPath)
}
}
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 { } else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
h.processXFilesRelations(xfile, options, relationPath) h.processXFilesRelations(xfile, options, relationPath)
} }

View File

@@ -0,0 +1,110 @@
package restheadspec
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestPreloadOption_TableName verifies that TableName field is properly used
// when provided in PreloadOption for WHERE clause processing
func TestPreloadOption_TableName(t *testing.T) {
tests := []struct {
name string
preload common.PreloadOption
expectedTable string
}{
{
name: "TableName provided explicitly",
preload: common.PreloadOption{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "mastertaskitem",
Where: "rid_parentmastertaskitem is null",
},
expectedTable: "mastertaskitem",
},
{
name: "TableName empty, should use empty string",
preload: common.PreloadOption{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "",
Where: "rid_parentmastertaskitem is null",
},
expectedTable: "",
},
{
name: "Simple relation without nested path",
preload: common.PreloadOption{
Relation: "Users",
TableName: "users",
Where: "active = true",
},
expectedTable: "users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test that the TableName field stores the correct value
if tt.preload.TableName != tt.expectedTable {
t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable)
}
// Verify that when TableName is provided, it should be used instead of extracting from relation
tableName := tt.preload.TableName
if tableName == "" {
// This simulates the fallback logic in handler.go
// In reality, reflection.ExtractTableNameOnly would be called
tableName = tt.expectedTable
}
if tableName != tt.expectedTable {
t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable)
}
})
}
}
// TestXFilesPreload_StoresTableName verifies that XFiles processing
// stores the table name in PreloadOption and doesn't add table prefixes to WHERE clauses
func TestXFilesPreload_StoresTableName(t *testing.T) {
handler := &Handler{}
xfiles := &XFiles{
TableName: "mastertaskitem",
Prefix: "MAL",
PrimaryKey: "rid_mastertaskitem",
RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem
Recursive: false, // Changed from true (recursive children are now skipped)
SqlAnd: []string{"rid_parentmastertaskitem is null"},
}
options := &ExtendedRequestOptions{}
// Process XFiles
handler.addXFilesPreload(xfiles, options, "MTL")
// Verify that a preload was added
if len(options.Preload) == 0 {
t.Fatal("Expected at least one preload to be added")
}
preload := options.Preload[0]
// Verify the table name is stored
if preload.TableName != "mastertaskitem" {
t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem")
}
// Verify the relation path includes the prefix
expectedRelation := "MTL.MAL"
if preload.Relation != expectedRelation {
t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation)
}
// Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs)
expectedWhere := "rid_parentmastertaskitem is null"
if preload.Where != expectedWhere {
t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere)
}
}

View File

@@ -0,0 +1,91 @@
package restheadspec
import (
"testing"
)
// TestPreloadWhereClause_WithJoins verifies that table prefixes are added
// to WHERE clauses when SqlJoins are present
func TestPreloadWhereClause_WithJoins(t *testing.T) {
tests := []struct {
name string
where string
sqlJoins []string
expectedPrefix bool
description string
}{
{
name: "No joins - no prefix needed",
where: "status = 'active'",
sqlJoins: []string{},
expectedPrefix: false,
description: "Without JOINs, Bun knows the table context",
},
{
name: "Has joins - prefix needed",
where: "status = 'active'",
sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"},
expectedPrefix: true,
description: "With JOINs, table prefix disambiguates columns",
},
{
name: "Already has prefix - no change",
where: "users.status = 'active'",
sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"},
expectedPrefix: true,
description: "Existing prefix should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// This test documents the expected behavior
// The actual logic is in handler.go lines 916-937
hasJoins := len(tt.sqlJoins) > 0
if hasJoins != tt.expectedPrefix {
t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v",
hasJoins, tt.expectedPrefix)
}
t.Logf("%s: %s", tt.name, tt.description)
})
}
}
// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins
// results in table prefixes being added to WHERE clauses
func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) {
handler := &Handler{}
xfiles := &XFiles{
TableName: "users",
Prefix: "USR",
PrimaryKey: "id",
SqlAnd: []string{"status = 'active'"},
SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"},
}
options := &ExtendedRequestOptions{}
handler.addXFilesPreload(xfiles, options, "")
if len(options.Preload) == 0 {
t.Fatal("Expected at least one preload to be added")
}
preload := options.Preload[0]
// Verify SqlJoins were stored
if len(preload.SqlJoins) != 1 {
t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins))
}
// Verify WHERE clause does NOT have prefix yet (added later in handler)
expectedWhere := "status = 'active'"
if preload.Where != expectedWhere {
t.Errorf("PreloadOption.Where = %q, want %q", preload.Where, expectedWhere)
}
// Note: The handler will add the prefix when it sees SqlJoins
// This is tested in the handler itself, not during XFiles parsing
}

View File

@@ -177,38 +177,46 @@ func TestXFilesRecursivePreload(t *testing.T) {
// Verify that preload options were created // Verify that preload options were created
require.NotEmpty(t, options.Preload, "Expected preload options to be created") require.NotEmpty(t, options.Preload, "Expected preload options to be created")
// Test 1: Verify recursive preload option has RelatedKey set // Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey
t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) { t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) {
// Find the recursive mastertaskitem preload // Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload *common.PreloadOption var recursivePreload *common.PreloadOption
for i := range options.Preload { for i := range options.Preload {
preload := &options.Preload[i] preload := &options.Preload[i]
if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload recursivePreload = preload
break break
} }
} }
require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload") require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL")
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RelatedKey,
"Recursive preload should have RelatedKey set from xfiles config") // RelatedKey should be the parent relationship key (MTL -> MAL)
assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey,
"Recursive preload should preserve original RelatedKey for parent relationship")
// RecursiveChildKey should be set from the recursive child config
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey,
"Recursive preload should have RecursiveChildKey set from recursive child config")
assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive") assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive")
}) })
// Test 2: Verify root level mastertaskitem has WHERE clause for filtering root items // Test 2: Verify mastertaskitem has WHERE clause for filtering root items
t.Run("RootLevelHasWhereClause", func(t *testing.T) { t.Run("RootLevelHasWhereClause", func(t *testing.T) {
var rootPreload *common.PreloadOption var rootPreload *common.PreloadOption
for i := range options.Preload { for i := range options.Preload {
preload := &options.Preload[i] preload := &options.Preload[i]
if preload.Relation == "mastertask.mastertaskitem" && !preload.Recursive { if preload.Relation == "MTL.MAL" {
rootPreload = preload rootPreload = preload
break break
} }
} }
require.NotNil(t, rootPreload, "Expected to find root mastertaskitem preload") require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload")
assert.NotEmpty(t, rootPreload.Where, "Root mastertaskitem should have WHERE clause") assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause")
// The WHERE clause should filter for root items (rid_parentmastertaskitem is null) // The WHERE clause should filter for root items (rid_parentmastertaskitem is null)
assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive")
}) })
// Test 3: Verify actiondefinition relation exists for mastertaskitem // Test 3: Verify actiondefinition relation exists for mastertaskitem
@@ -216,7 +224,7 @@ func TestXFilesRecursivePreload(t *testing.T) {
var defPreload *common.PreloadOption var defPreload *common.PreloadOption
for i := range options.Preload { for i := range options.Preload {
preload := &options.Preload[i] preload := &options.Preload[i]
if preload.Relation == "mastertask.mastertaskitem.actiondefinition" { if preload.Relation == "MTL.MAL.DEF" {
defPreload = preload defPreload = preload
break break
} }
@@ -229,18 +237,18 @@ func TestXFilesRecursivePreload(t *testing.T) {
// Test 4: Verify relation name generation with mock query // Test 4: Verify relation name generation with mock query
t.Run("RelationNameGeneration", func(t *testing.T) { t.Run("RelationNameGeneration", func(t *testing.T) {
// Find the recursive mastertaskitem preload // Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption var recursivePreload common.PreloadOption
found := false found := false
for _, preload := range options.Preload { for _, preload := range options.Preload {
if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload recursivePreload = preload
found = true found = true
break break
} }
} }
require.True(t, found, "Expected to find recursive mastertaskitem preload") require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
// Create mock query to track operations // Create mock query to track operations
mockQuery := &mockSelectQuery{operations: []string{}} mockQuery := &mockSelectQuery{operations: []string{}}
@@ -251,43 +259,37 @@ func TestXFilesRecursivePreload(t *testing.T) {
// Verify the correct FK-based relation name was generated // Verify the correct FK-based relation name was generated
foundCorrectRelation := false foundCorrectRelation := false
foundIncorrectRelation := false
for _, op := range mock.operations { for _, op := range mock.operations {
// Should generate: mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM // Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM
if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM" { if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundCorrectRelation = true foundCorrectRelation = true
} }
// Should NOT generate: mastertask.mastertaskitem.mastertaskitem.mastertaskitem
if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem" {
foundIncorrectRelation = true
}
} }
assert.True(t, foundCorrectRelation, assert.True(t, foundCorrectRelation,
"Expected FK-based relation name 'mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v", "Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v",
mock.operations) mock.operations)
assert.False(t, foundIncorrectRelation,
"Should NOT generate simple relation name when RelatedKey is set")
}) })
// Test 5: Verify WHERE clause is cleared for recursive levels // Test 5: Verify WHERE clause is cleared for recursive levels
t.Run("WhereClauseClearedForChildren", func(t *testing.T) { t.Run("WhereClauseClearedForChildren", func(t *testing.T) {
// Find the recursive mastertaskitem preload with WHERE clause // Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption var recursivePreload common.PreloadOption
found := false found := false
for _, preload := range options.Preload { for _, preload := range options.Preload {
if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload recursivePreload = preload
found = true found = true
break break
} }
} }
require.True(t, found, "Expected to find recursive mastertaskitem preload") require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
// The root level might have a WHERE clause // The root level has a WHERE clause (rid_parentmastertaskitem is null)
// But when we apply recursion, it should be cleared // But when we apply recursion, it should be cleared
assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause")
mockQuery := &mockSelectQuery{operations: []string{}} mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
@@ -297,7 +299,7 @@ func TestXFilesRecursivePreload(t *testing.T) {
// We check that the recursive relation was created (which means WHERE was cleared internally) // We check that the recursive relation was created (which means WHERE was cleared internally)
foundRecursiveRelation := false foundRecursiveRelation := false
for _, op := range mock.operations { for _, op := range mock.operations {
if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM" { if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundRecursiveRelation = true foundRecursiveRelation = true
} }
} }
@@ -308,29 +310,29 @@ func TestXFilesRecursivePreload(t *testing.T) {
// Test 6: Verify child relations are extended to recursive levels // Test 6: Verify child relations are extended to recursive levels
t.Run("ChildRelationsExtended", func(t *testing.T) { t.Run("ChildRelationsExtended", func(t *testing.T) {
// Find both the recursive mastertaskitem and the actiondefinition preloads // Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption var recursivePreload common.PreloadOption
foundRecursive := false foundRecursive := false
for _, preload := range options.Preload { for _, preload := range options.Preload {
if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload recursivePreload = preload
foundRecursive = true foundRecursive = true
break break
} }
} }
require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload") require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL")
mockQuery := &mockSelectQuery{operations: []string{}} mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery) mock := result.(*mockSelectQuery)
// actiondefinition should be extended to the recursive level // actiondefinition should be extended to the recursive level
// Expected: mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM.actiondefinition // Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF
foundExtendedDEF := false foundExtendedDEF := false
for _, op := range mock.operations { for _, op := range mock.operations {
if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM.actiondefinition" { if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
foundExtendedDEF = true foundExtendedDEF = true
} }
} }