feat(database): add custom preload handling for relations

* Introduced custom preloads to handle relations whose generated aliases may exceed PostgreSQL's 63-character identifier limit.
* Implemented checks for alias length to prevent truncation warnings.
* Enhanced the loading mechanism for nested relations using separate queries.
This commit is contained in:
Hein
2026-02-02 18:39:48 +02:00
parent 7600a6d1fb
commit 646620ed83
2 changed files with 577 additions and 1 deletions

View File

@@ -211,6 +211,7 @@ type BunSelectQuery struct {
inJoinContext bool // Track if we're in a JOIN relation context inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions joinTableAlias string // Alias to use for JOIN conditions
skipAutoDetect bool // Skip auto-detection to prevent circular calls skipAutoDetect bool // Skip auto-detection to prevent circular calls
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
} }
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -480,6 +481,25 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
} }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Check if this relation will likely cause alias truncation FIRST
// PostgreSQL has a 63-character limit on identifiers
willTruncate := checkAliasLength(relation)
if willTruncate {
logger.Warn("Preload relation '%s' would generate aliases exceeding PostgreSQL's 63-char limit", relation)
logger.Info("Using custom preload implementation with separate queries for relation '%s'", relation)
// Store this relation for custom post-processing after the main query
// We'll load it manually with separate queries to avoid JOIN aliases
if b.customPreloads == nil {
b.customPreloads = make(map[string][]func(common.SelectQuery) common.SelectQuery)
}
b.customPreloads[relation] = apply
// Return without calling Bun's Relation() - we'll handle it ourselves
return b
}
// Auto-detect relationship type and choose optimal loading strategy // Auto-detect relationship type and choose optimal loading strategy
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation) // Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
if !b.skipAutoDetect { if !b.skipAutoDetect {
@@ -490,8 +510,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Log the detected relationship type // Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType) logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() { if relType.ShouldUseJoin() {
// If this is a belongs-to or has-one relation that won't exceed limits, use JOIN for better performance
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation) logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return b.JoinRelation(relation, apply...) return b.JoinRelation(relation, apply...)
} }
@@ -504,6 +524,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
} }
// Use Bun's native Relation() for preloading // Use Bun's native Relation() for preloading
// Note: For relations that would cause truncation, skipAutoDetect is set to true
// to prevent our auto-detection from adding JOIN optimization
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -561,6 +583,507 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
return b return b
} }
// checkIfRelationAlreadyLoaded reports whether the relation field relationName
// is already populated on the given parent records. When it is, the related
// records found on all parents are returned as one flat slice.
//
// parents must be a slice (or array) value whose elements are structs or
// pointers to structs. The first parent decides the shape of the relation:
//
//   - slice field (has-many): the relation counts as loaded when any parent
//     holds a non-empty slice; the result has the field's own slice type.
//   - pointer field (has-one/belongs-to): the relation counts as loaded only
//     when the first parent's pointer is non-nil; the result is a slice of
//     that pointer type holding every non-nil pointer.
//
// Any other field kind, a missing field, or an empty parents slice yields
// (reflect.Value{}, false). Parents that are nil pointers or not structs are
// skipped instead of panicking.
func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (reflect.Value, bool) {
	if parents.Len() == 0 {
		return reflect.Value{}, false
	}

	// relationOf returns the relation field of the i-th parent, or an invalid
	// Value when that parent cannot carry one. FieldByName panics on anything
	// that is not a struct, so nil pointers and other kinds are filtered here.
	relationOf := func(i int) reflect.Value {
		parent := parents.Index(i)
		if parent.Kind() == reflect.Ptr {
			if parent.IsNil() {
				return reflect.Value{}
			}
			parent = parent.Elem()
		}
		if parent.Kind() != reflect.Struct {
			return reflect.Value{}
		}
		return parent.FieldByName(relationName)
	}

	first := relationOf(0)
	if !first.IsValid() {
		return reflect.Value{}, false
	}

	switch first.Kind() {
	case reflect.Slice:
		// has-many: gather the children of every parent in parent order.
		allRelated := reflect.MakeSlice(first.Type(), 0, first.Len())
		for i := 0; i < parents.Len(); i++ {
			if f := relationOf(i); f.IsValid() {
				allRelated = reflect.AppendSlice(allRelated, f)
			}
		}
		// Nothing collected means no parent had a populated slice.
		if allRelated.Len() == 0 {
			return reflect.Value{}, false
		}
		return allRelated, true

	case reflect.Ptr:
		// has-one/belongs-to: only the first parent is consulted to decide
		// whether the relation was loaded.
		if first.IsNil() {
			return reflect.Value{}, false
		}
		allRelated := reflect.MakeSlice(reflect.SliceOf(first.Type()), 0, parents.Len())
		for i := 0; i < parents.Len(); i++ {
			if f := relationOf(i); f.IsValid() && !f.IsNil() {
				allRelated = reflect.Append(allRelated, f)
			}
		}
		return allRelated, true
	}

	return reflect.Value{}, false
}
// loadCustomPreloads resolves every relation registered in b.customPreloads
// with separate queries, one per level of the dotted relation path, instead
// of JOIN aliases.
//
// It reads the records from the model attached to b.query. Only slice models
// are handled; any other kind is logged and ignored, as is an empty slice.
// A level that checkIfRelationAlreadyLoaded reports as populated is reused
// as-is. Otherwise the level is fetched through loadRelationLevel, and the
// records it returns become the parents of the next level. The apply
// functions are handed to every level, but loadRelationLevel is told which
// level is the last one.
//
// It returns an error when no model is set or when loading a level fails.
func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
	model := b.query.GetModel()
	if model == nil || model.Value() == nil {
		return fmt.Errorf("no model to load preloads for")
	}

	// Unwrap the pointer the model was registered with to reach the records.
	rows := reflect.ValueOf(model.Value())
	if rows.Kind() == reflect.Ptr {
		rows = rows.Elem()
	}

	switch {
	case rows.Kind() != reflect.Slice:
		logger.Warn("Custom preloads only support slice models currently, got: %v", rows.Kind())
		return nil
	case rows.Len() == 0:
		logger.Debug("No records to load preloads for")
		return nil
	}

	for relation, applyFuncs := range b.customPreloads {
		logger.Info("Loading custom preload for relation: %s", relation)

		// "MTL.MAL.DEF" is walked as MTL, then MAL, then DEF.
		segments := strings.Split(relation, ".")
		lastIdx := len(segments) - 1

		// Every path starts from the records of the main query.
		level := rows

		for idx, segment := range segments {
			logger.Debug("Loading relation part [%d/%d]: %s", idx+1, len(segments), segment)

			// Reuse what is already populated rather than querying again.
			if existing, loaded := checkIfRelationAlreadyLoaded(level, segment); loaded && existing.IsValid() && existing.Len() > 0 {
				logger.Info("Relation '%s' already loaded by Bun, using existing %d records", segment, existing.Len())
				level = existing
				continue
			}

			fetched, err := b.loadRelationLevel(ctx, level, segment, idx == lastIdx, applyFuncs)
			if err != nil {
				return fmt.Errorf("failed to load relation %s (part %s): %w", relation, segment, err)
			}

			// The final level has no children to descend into.
			if idx == lastIdx {
				continue
			}

			// Without parents there is nothing deeper to load for this path.
			if !fetched.IsValid() || fetched.Len() == 0 {
				logger.Debug("No records loaded at level %s, stopping nested preload", segment)
				break
			}

			logger.Debug("Collected %d records for next level", fetched.Len())
			level = fetched
		}
	}

	return nil
}
// loadRelationLevel loads a single level of a relation for a set of parent
// records using a separate SELECT, then attaches the loaded records to the
// parents through associateRelatedRecords.
//
// parentRecords must be a slice whose elements are structs or struct
// pointers. relationName is the Go field name of the relation on the parent
// struct; its `bun` tag is parsed with parseRelationTag to find the join
// keys. The query selects rows whose foreign key column is IN the set of
// local key values collected from the parents. applyFuncs are applied to the
// query only when isLast is true.
//
// It returns the loaded records (a slice of pointers to the related struct
// type) so they can serve as parents for the next level. An invalid Value
// with a nil error is returned when there are no parents or no key values.
// An error is returned when the relation field or its tag cannot be
// resolved, when an apply function returns something other than a non-nil
// *BunSelectQuery, or when the query or association fails.
func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords reflect.Value, relationName string, isLast bool, applyFuncs []func(common.SelectQuery) common.SelectQuery) (reflect.Value, error) {
	if parentRecords.Len() == 0 {
		return reflect.Value{}, nil
	}

	// Get the first record to inspect the struct type
	firstRecord := parentRecords.Index(0)
	if firstRecord.Kind() == reflect.Ptr {
		firstRecord = firstRecord.Elem()
	}

	if firstRecord.Kind() != reflect.Struct {
		return reflect.Value{}, fmt.Errorf("expected struct, got %v", firstRecord.Kind())
	}

	parentType := firstRecord.Type()

	// Find the relation field in the struct
	structField, found := parentType.FieldByName(relationName)
	if !found {
		return reflect.Value{}, fmt.Errorf("relation field %s not found in struct %s", relationName, parentType.Name())
	}

	// Parse the bun tag to get relation info
	bunTag := structField.Tag.Get("bun")
	logger.Debug("Relation %s bun tag: %s", relationName, bunTag)

	relInfo, err := parseRelationTag(bunTag)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("failed to parse relation tag for %s: %w", relationName, err)
	}

	logger.Debug("Parsed relation: type=%s, join=%s", relInfo.relType, relInfo.joinCondition)

	// Collect the distinct local key values of the parents; they drive the IN clause.
	fkValues, err := extractForeignKeyValues(parentRecords, relInfo.localKey)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("failed to extract FK values: %w", err)
	}

	if len(fkValues) == 0 {
		logger.Debug("No foreign key values to load for relation %s", relationName)
		return reflect.Value{}, nil
	}

	logger.Debug("Loading %d related records for %s (FK values: %v)", len(fkValues), relationName, fkValues)

	// Reduce the field type ([]T, []*T, T or *T) to the bare related struct type.
	relatedType := structField.Type
	isSlice := relatedType.Kind() == reflect.Slice
	if isSlice {
		relatedType = relatedType.Elem()
	}
	if relatedType.Kind() == reflect.Ptr {
		relatedType = relatedType.Elem()
	}

	// Results are always scanned into a []*T, whatever the field's own type is.
	resultsSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(relatedType)), 0, len(fkValues))
	resultsPtr := reflect.New(resultsSlice.Type())
	resultsPtr.Elem().Set(resultsSlice)

	// Build the query
	query := b.db.NewSelect().Model(resultsPtr.Interface())

	// Apply WHERE clause: foreign_key IN (values...)
	query = query.Where(fmt.Sprintf("%s IN (?)", relInfo.foreignKey), bun.In(fkValues))

	// Apply user's functions (if any), only on the final level of the path
	if isLast && len(applyFuncs) > 0 {
		wrapper := &BunSelectQuery{query: query, db: b.db}
		for _, fn := range applyFuncs {
			if fn == nil {
				continue
			}
			// An apply function is free to return nil or another
			// common.SelectQuery implementation; report that as an error
			// instead of panicking on the type assertion.
			applied, ok := fn(wrapper).(*BunSelectQuery)
			if !ok || applied == nil {
				return reflect.Value{}, fmt.Errorf("preload function for relation %s returned an unsupported query type", relationName)
			}
			wrapper = applied
			query = wrapper.query
		}
	}

	// Execute the query
	err = query.Scan(ctx)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("failed to load related records: %w", err)
	}

	loadedRecords := resultsPtr.Elem()
	logger.Info("Loaded %d related records for relation %s", loadedRecords.Len(), relationName)

	// Associate loaded records back to parent records
	err = associateRelatedRecords(parentRecords, loadedRecords, relationName, relInfo, isSlice)
	if err != nil {
		return reflect.Value{}, err
	}

	// Return the loaded records for use in nested preloads
	return loadedRecords, nil
}
// relationInfo is the relation metadata read from a bun struct tag.
type relationInfo struct {
	relType       string // Text following "rel:" (has-one, has-many, belongs-to)
	localKey      string // Left-hand side of the join condition: key in parent table
	foreignKey    string // Right-hand side of the join condition: key in related table
	joinCondition string // Raw text following "join:"
}

// parseRelationTag reads the relation type and join keys out of a bun struct
// tag such as
//
//	rel:has-one,join:rid_mastertaskitem=rid_mastertaskitem
//
// The tag is split on commas and each segment is trimmed. A "rel:" segment
// sets the relation type. A "join:" segment sets the join condition, and,
// when it contains exactly one "=", the local key (left) and foreign key
// (right). When a segment kind occurs more than once the later one wins.
//
// It returns an error when the relation type, the local key or the foreign
// key is still empty after the whole tag has been scanned.
func parseRelationTag(tag string) (*relationInfo, error) {
	const (
		relPrefix  = "rel:"
		joinPrefix = "join:"
	)

	var parsed relationInfo

	for _, segment := range strings.Split(tag, ",") {
		segment = strings.TrimSpace(segment)

		switch {
		case strings.HasPrefix(segment, relPrefix):
			parsed.relType = segment[len(relPrefix):]

		case strings.HasPrefix(segment, joinPrefix):
			parsed.joinCondition = segment[len(joinPrefix):]

			// Only a plain "local=foreign" pair yields keys; anything else
			// leaves previously parsed keys untouched.
			if strings.Count(parsed.joinCondition, "=") != 1 {
				continue
			}
			eq := strings.IndexByte(parsed.joinCondition, '=')
			parsed.localKey = strings.TrimSpace(parsed.joinCondition[:eq])
			parsed.foreignKey = strings.TrimSpace(parsed.joinCondition[eq+1:])
		}
	}

	complete := parsed.relType != "" && parsed.localKey != "" && parsed.foreignKey != ""
	if !complete {
		return nil, fmt.Errorf("incomplete relation tag: %s", tag)
	}

	return &parsed, nil
}
// extractForeignKeyValues collects the distinct values of the key named
// fkFieldName from records, in first-seen order.
//
// records must be a slice (or array) value whose elements are structs or
// struct pointers; nil pointers and non-struct elements are skipped. On each
// record the key field is looked up, in order, by:
//
//  1. the exact Go field name,
//  2. the name with its first letter upper-cased,
//  3. the first field whose json tag name or bun tag column name (the text
//     before the first comma) equals fkFieldName.
//
// Tags are compared by exact name rather than by substring, so a relation
// field tagged `join:x=x` or a column called `x_other` is never mistaken for
// column `x`.
//
// Records without a matching, exported field are skipped. Values exposing
// IsNull() bool are skipped when null, and values exposing Int64() int64 are
// replaced by that integer. The returned error is always nil.
func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]interface{}, error) {
	values := make([]interface{}, 0, records.Len())
	seenValues := make(map[interface{}]struct{}, records.Len())

	// tagName returns the name part of a struct tag value: everything before
	// the first comma, with surrounding spaces removed.
	tagName := func(tag string) string {
		if i := strings.IndexByte(tag, ','); i >= 0 {
			tag = tag[:i]
		}
		return strings.TrimSpace(tag)
	}

	for i := 0; i < records.Len(); i++ {
		record := records.Index(i)
		if record.Kind() == reflect.Ptr {
			if record.IsNil() {
				continue // A nil record carries no key
			}
			record = record.Elem()
		}
		// FieldByName and NumField panic on anything that is not a struct.
		if record.Kind() != reflect.Struct {
			continue
		}

		// Find the FK field - try both exact name and capitalized version
		fkField := record.FieldByName(fkFieldName)
		if !fkField.IsValid() && fkFieldName != "" {
			fkField = record.FieldByName(strings.ToUpper(fkFieldName[:1]) + fkFieldName[1:])
		}

		// Fall back to the column name declared in the json or bun tag.
		if !fkField.IsValid() && fkFieldName != "" {
			recordType := record.Type()
			for j := 0; j < recordType.NumField(); j++ {
				tags := recordType.Field(j).Tag
				if tagName(tags.Get("json")) == fkFieldName || tagName(tags.Get("bun")) == fkFieldName {
					fkField = record.Field(j)
					break
				}
			}
		}

		// Skip records without FK, and unexported fields that cannot be read.
		if !fkField.IsValid() || !fkField.CanInterface() {
			continue
		}

		value := fkField.Interface()

		// Handle SqlNull types
		if nullType, ok := value.(interface{ IsNull() bool }); ok && nullType.IsNull() {
			continue
		}

		// Handle types with Int64() method
		if int64er, ok := value.(interface{ Int64() int64 }); ok {
			value = int64er.Int64()
		}

		// Deduplicate
		if _, dup := seenValues[value]; dup {
			continue
		}
		seenValues[value] = struct{}{}
		values = append(values, value)
	}

	return values, nil
}
// associateRelatedRecords attaches loaded related records to the relation
// field fieldName of their parent records.
//
// related is grouped by the value of relInfo.foreignKey, and every parent is
// matched by the value of relInfo.localKey; both sides are resolved with
// findFieldByName and normalized with extractFieldValue. Records or parents
// whose key field is missing or yields nil are skipped, as are nil pointers.
//
// When isSlice is true the parent's field is replaced by a new slice holding
// all matches. Otherwise the first match is stored, but only if the field is
// still at its zero value. Matches are normally pointers; when the relation
// field stores values ([]T or T) the pointer is dereferenced before it is
// stored, and a match of an incompatible type is logged and skipped.
//
// The returned error is always nil; problems are reported through the logger.
func associateRelatedRecords(parents, related reflect.Value, fieldName string, relInfo *relationInfo, isSlice bool) error {
	logger.Debug("Associating %d related records to %d parents for field '%s'", related.Len(), parents.Len(), fieldName)

	// Build a map: foreignKey -> related record(s)
	relatedMap := make(map[interface{}][]reflect.Value)

	for i := 0; i < related.Len(); i++ {
		relRecord := related.Index(i)
		relRecordElem := relRecord
		if relRecordElem.Kind() == reflect.Ptr {
			if relRecordElem.IsNil() {
				continue
			}
			relRecordElem = relRecordElem.Elem()
		}

		// Get the foreign key value from the related record - try multiple variations
		fkField := findFieldByName(relRecordElem, relInfo.foreignKey)
		if !fkField.IsValid() {
			logger.Warn("Could not find FK field '%s' in related record type %s", relInfo.foreignKey, relRecordElem.Type().Name())
			continue
		}

		fkValue := extractFieldValue(fkField)
		if fkValue == nil {
			continue
		}

		// Keep the original element (usually a pointer) so parents share it.
		relatedMap[fkValue] = append(relatedMap[fkValue], relRecord)
	}

	logger.Debug("Built related map with %d unique FK values", len(relatedMap))

	// fit adapts a loaded record to the type stored by the relation field:
	// it is used directly when assignable, or dereferenced when the field
	// stores values rather than pointers.
	fit := func(rec reflect.Value, target reflect.Type) (reflect.Value, bool) {
		if rec.Type().AssignableTo(target) {
			return rec, true
		}
		if rec.Kind() == reflect.Ptr && !rec.IsNil() && rec.Elem().Type().AssignableTo(target) {
			return rec.Elem(), true
		}
		return reflect.Value{}, false
	}

	// Associate with parents
	associatedCount := 0
	for i := 0; i < parents.Len(); i++ {
		parent := parents.Index(i)
		if parent.Kind() == reflect.Ptr {
			if parent.IsNil() {
				continue
			}
			parent = parent.Elem()
		}

		// Get the local key value from parent
		localField := findFieldByName(parent, relInfo.localKey)
		if !localField.IsValid() {
			logger.Warn("Could not find local key field '%s' in parent type %s", relInfo.localKey, parent.Type().Name())
			continue
		}

		localValue := extractFieldValue(localField)
		if localValue == nil {
			continue
		}

		// Find matching related records
		matches := relatedMap[localValue]
		if len(matches) == 0 {
			continue
		}

		// The field is looked up on the dereferenced parent struct; it is
		// settable only when the parent itself is addressable.
		relationField := parent.FieldByName(fieldName)
		if !relationField.IsValid() {
			logger.Warn("Relation field '%s' not found in parent type %s", fieldName, parent.Type().Name())
			continue
		}

		if !relationField.CanSet() {
			logger.Warn("Relation field '%s' cannot be set (unexported?)", fieldName)
			continue
		}

		if isSlice {
			// For has-many: replace entire slice (don't append to avoid duplicates)
			elemType := relationField.Type().Elem()
			newSlice := reflect.MakeSlice(relationField.Type(), 0, len(matches))
			for _, match := range matches {
				item, ok := fit(match, elemType)
				if !ok {
					logger.Warn("Related record type %s is not assignable to field '%s'", match.Type(), fieldName)
					continue
				}
				newSlice = reflect.Append(newSlice, item)
			}
			relationField.Set(newSlice)
			associatedCount += newSlice.Len()
			logger.Debug("Set has-many field '%s' with %d records for parent %d", fieldName, newSlice.Len(), i)
			continue
		}

		// For has-one/belongs-to: only set if not already set (avoid duplicates).
		// IsZero covers both nil pointers and untouched struct values.
		if !relationField.IsZero() {
			logger.Debug("Skipping has-one field '%s' for parent %d (already set)", fieldName, i)
			continue
		}

		item, ok := fit(matches[0], relationField.Type())
		if !ok {
			logger.Warn("Related record type %s is not assignable to field '%s'", matches[0].Type(), fieldName)
			continue
		}
		relationField.Set(item)
		associatedCount++
		logger.Debug("Set has-one field '%s' for parent %d", fieldName, i)
	}

	logger.Info("Associated %d related records to %d parents for field '%s'", associatedCount, parents.Len(), fieldName)
	return nil
}
// findFieldByName locates a field of the struct value v from either its Go
// name or its column name. Lookups are tried in this order:
//
//  1. the exact Go field name,
//  2. the name with its first letter upper-cased,
//  3. the first field whose json tag name or bun tag column name (the text
//     before the first comma) equals name.
//
// Tag names are compared exactly, so "id" does not match a column called
// "author_id" and "type" does not match a `type:...` tag option. A tag that
// consists only of the column name (no options) is matched as well.
//
// It returns an invalid reflect.Value when nothing matches, when name is
// empty, or when v is not a struct.
func findFieldByName(v reflect.Value, name string) reflect.Value {
	// FieldByName and NumField panic on anything that is not a struct.
	if v.Kind() != reflect.Struct || name == "" {
		return reflect.Value{}
	}

	// Try exact name
	if field := v.FieldByName(name); field.IsValid() {
		return field
	}

	// Try with capital first letter
	if field := v.FieldByName(strings.ToUpper(name[:1]) + name[1:]); field.IsValid() {
		return field
	}

	// tagName returns the name part of a struct tag value: everything before
	// the first comma, with surrounding spaces removed.
	tagName := func(tag string) string {
		if i := strings.IndexByte(tag, ','); i >= 0 {
			tag = tag[:i]
		}
		return strings.TrimSpace(tag)
	}

	// Try searching by json or bun tag
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		tags := t.Field(i).Tag
		if tagName(tags.Get("json")) == name || tagName(tags.Get("bun")) == name {
			return v.Field(i)
		}
	}

	return reflect.Value{}
}
// extractFieldValue returns a plain value for field.
//
// The checks run in this order and the first one that applies decides:
//
//   - a field that cannot be read through Interface() (for example an
//     unexported one) yields nil;
//   - a value exposing IsNull() bool that reports true yields nil;
//   - a value exposing Int64() int64 yields that int64;
//   - a value exposing String() string yields that string;
//   - anything else is returned unchanged.
func extractFieldValue(field reflect.Value) interface{} {
	type (
		nullable interface{ IsNull() bool }
		int64er  interface{ Int64() int64 }
		stringer interface{ String() string }
	)

	if !field.CanInterface() {
		return nil
	}
	raw := field.Interface()

	// SQL null wrappers: a null value has nothing to compare.
	if n, isNullable := raw.(nullable); isNullable && n.IsNull() {
		return nil
	}

	// Integer wrappers are reduced to a plain int64.
	if n, hasInt64 := raw.(int64er); hasInt64 {
		return n.Int64()
	}

	// Remaining wrappers fall back to their string form.
	if s, hasString := raw.(stringer); hasString {
		return s.String()
	}

	return raw
}
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a LEFT JOIN instead of a separate query // JoinRelation uses a LEFT JOIN instead of a separate query
// This is more efficient for many-to-one or one-to-one relationships // This is more efficient for many-to-one or one-to-one relationships
@@ -700,6 +1223,15 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
return err return err
} }
// After main query, load custom preloads using separate queries
if len(b.customPreloads) > 0 {
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
if err := b.loadCustomPreloads(ctx); err != nil {
logger.Error("Failed to load custom preloads: %v", err)
return err
}
}
return nil return nil
} }

View File

@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/uptrace/bun/dialect/mssqldialect" "github.com/uptrace/bun/dialect/mssqldialect"
"github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/dialect/sqlitedialect" "github.com/uptrace/bun/dialect/sqlitedialect"
@@ -13,6 +14,49 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// postgresIdentifierLimit is the maximum length of a PostgreSQL identifier
// (63 bytes + null terminator = 64 bytes total).
const postgresIdentifierLimit = 63
// checkAliasLength estimates whether preloading the dotted relation path
// relation would produce column aliases longer than postgresIdentifierLimit.
// It returns true, after logging a warning, when truncation is likely.
//
// The estimate models an alias as
//
//	relation1__relation2__...__columnname
//
// that is, the lower-cased path segments joined by "__", a further "__"
// separator, and a column name assumed to be at most 35 characters long.
// With a 63-character limit this flags every path whose joined prefix is
// longer than 26 characters. Paths with a single segment are never flagged.
//
// NOTE(review): the prefix is only lower-cased here; if Bun derives aliases
// differently (for example by snake-casing field names) real aliases may be
// longer than this estimate — confirm against Bun's alias generation.
func checkAliasLength(relation string) bool {
	parts := strings.Split(relation, ".")
	if len(parts) <= 1 {
		return false // Single level relations are fine
	}

	const (
		// Length of the "__" between the alias prefix and the column name.
		columnSeparatorLen = 2
		// Column names in the error were things like "rid_mastertype_hubtype"
		// (23 chars); to be safe, assume the longest could be around 35 chars.
		maxColumnNameLen = 35
	)

	aliasPrefixLen := len(strings.ToLower(strings.Join(parts, "__")))
	estimatedMaxLen := aliasPrefixLen + columnSeparatorLen + maxColumnNameLen

	// A separate "prefix within 15 chars of the limit" check is not needed:
	// any such prefix already exceeds the limit in the estimate above.
	if estimatedMaxLen <= postgresIdentifierLimit {
		return false
	}

	logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
		relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
	return true
}
// parseTableName splits a table name that may contain schema into separate schema and table // parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users") // For example: "public.users" -> ("public", "users")
// //