mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-15 07:24:25 +00:00
feat(database): ✨ Enhance Preload and Join functionality
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -22m33s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -22m11s
Build , Vet Test, and Lint / Build (push) Successful in -26m39s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m53s
Tests / Integration Tests (push) Failing after -27m30s
Tests / Unit Tests (push) Successful in -27m5s
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -22m33s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -22m11s
Build , Vet Test, and Lint / Build (push) Successful in -26m39s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m53s
Tests / Integration Tests (push) Failing after -27m30s
Tests / Unit Tests (push) Successful in -27m5s
* Introduce skipAutoDetect flag to prevent circular calls in PreloadRelation. * Improve handling of long alias chains in PreloadRelation. * Ensure JoinRelation uses PreloadRelation without causing recursion. * Clear deferred preloads after execution to prevent re-execution. feat(recursive_crud): ✨ Filter valid fields in nested CUD processing * Add filterValidFields method to validate input data against model structure. * Use reflection to ensure only valid fields are processed. feat(reflection): ✨ Add utility to get valid JSON field names * Implement GetValidJSONFieldNames to retrieve valid JSON field names from model. * Enhance field validation during nested CUD operations. fix(handler): 🐛 Adjust recursive preload depth limit * Change recursive preload depth limit from 5 to 4 to prevent excessive recursion.
This commit is contained in:
@@ -211,6 +211,7 @@ type BunSelectQuery struct {
|
||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
skipAutoDetect bool // Skip auto-detection to prevent circular calls
|
||||
}
|
||||
|
||||
// deferredPreload represents a preload that will be executed as a separate query
|
||||
@@ -531,22 +532,25 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Auto-detect relationship type and choose optimal loading strategy
|
||||
// Get the model from the query if available
|
||||
model := b.query.GetModel()
|
||||
if model != nil && model.Value() != nil {
|
||||
relType := reflection.GetRelationType(model.Value(), relation)
|
||||
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
|
||||
if !b.skipAutoDetect {
|
||||
model := b.query.GetModel()
|
||||
if model != nil && model.Value() != nil {
|
||||
relType := reflection.GetRelationType(model.Value(), relation)
|
||||
|
||||
// Log the detected relationship type
|
||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||
// Log the detected relationship type
|
||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||
|
||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||
if relType.ShouldUseJoin() {
|
||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||
return b.JoinRelation(relation, apply...)
|
||||
}
|
||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||
if relType.ShouldUseJoin() {
|
||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||
return b.JoinRelation(relation, apply...)
|
||||
}
|
||||
|
||||
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -559,7 +563,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
const safeAliasLimit = 35 // Leave room for column names
|
||||
|
||||
// If the alias chain is too long, defer this preload to be executed as a separate query
|
||||
if len(aliasChain) > safeAliasLimit {
|
||||
if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit {
|
||||
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
|
||||
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
|
||||
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
|
||||
@@ -683,6 +687,10 @@ func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.Sele
|
||||
|
||||
// Use PreloadRelation with the wrapped functions
|
||||
// Bun's Relation() will use JOIN for belongs-to and has-one relations
|
||||
// CRITICAL: Set skipAutoDetect flag to prevent circular call
|
||||
// (PreloadRelation would detect belongs-to and call JoinRelation again)
|
||||
b.skipAutoDetect = true
|
||||
defer func() { b.skipAutoDetect = false }()
|
||||
return b.PreloadRelation(relation, wrappedApply...)
|
||||
}
|
||||
|
||||
@@ -742,6 +750,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||
// Don't fail the whole query, just log the warning
|
||||
}
|
||||
// Clear deferred preloads to prevent re-execution
|
||||
b.deferredPreloads = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -810,6 +820,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||
// Don't fail the whole query, just log the warning
|
||||
}
|
||||
// Clear deferred preloads to prevent re-execution
|
||||
b.deferredPreloads = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -898,13 +910,30 @@ func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the interface value to pass to Bun
|
||||
parentValue := parentField.Interface()
|
||||
// Get a pointer to the parent field so Bun can modify it
|
||||
// CRITICAL: We need to pass a pointer, not a value, so that when Bun
|
||||
// loads the child records and appends them to the slice, the changes
|
||||
// are reflected in the original struct field.
|
||||
var parentPtr interface{}
|
||||
if parentField.Kind() == reflect.Ptr {
|
||||
// Field is already a pointer (e.g., Parent *Parent), use as-is
|
||||
parentPtr = parentField.Interface()
|
||||
} else {
|
||||
// Field is a value (e.g., Comments []Comment), get its address
|
||||
if parentField.CanAddr() {
|
||||
parentPtr = parentField.Addr().Interface()
|
||||
} else {
|
||||
return fmt.Errorf("cannot get address of field '%s'", parentRelation)
|
||||
}
|
||||
}
|
||||
|
||||
// Load the child relation on the parent record
|
||||
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
|
||||
// CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent
|
||||
// record, not the first parent in the database table.
|
||||
return b.db.NewSelect().
|
||||
Model(parentValue).
|
||||
Model(parentPtr).
|
||||
WherePK().
|
||||
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
// Apply any custom query modifications
|
||||
if len(apply) > 0 {
|
||||
|
||||
@@ -98,6 +98,10 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
}
|
||||
|
||||
// Filter regularData to only include fields that exist in the model
|
||||
// Use MapToStruct to validate and filter fields
|
||||
regularData = p.filterValidFields(regularData, model)
|
||||
|
||||
// Inject parent IDs for foreign key resolution
|
||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||
|
||||
@@ -187,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
|
||||
return ""
|
||||
}
|
||||
|
||||
// filterValidFields filters input data to only include fields that exist in the model
|
||||
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
|
||||
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model to use with MapToStruct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model
|
||||
tempModel := reflect.New(modelType).Interface()
|
||||
|
||||
// Use MapToStruct to map the data - this will only map valid fields
|
||||
err := reflection.MapToStruct(data, tempModel)
|
||||
if err != nil {
|
||||
logger.Debug("Error mapping data to model: %v", err)
|
||||
return data
|
||||
}
|
||||
|
||||
// Extract the mapped fields back into a map
|
||||
// This effectively filters out any fields that don't exist in the model
|
||||
filteredData := make(map[string]interface{})
|
||||
tempModelValue := reflect.ValueOf(tempModel).Elem()
|
||||
|
||||
for key, value := range data {
|
||||
// Check if the field was successfully mapped
|
||||
if fieldWasMapped(tempModelValue, modelType, key) {
|
||||
filteredData[key] = value
|
||||
} else {
|
||||
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredData
|
||||
}
|
||||
|
||||
// fieldWasMapped checks if a field with the given key was mapped to the model
|
||||
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
|
||||
// Look for the field by JSON tag or field name
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check bun tag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check lowercase field name
|
||||
if strings.EqualFold(field.Name, key) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle embedded structs recursively
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
embeddedValue := modelValue.Field(i)
|
||||
if embeddedValue.Kind() == reflect.Ptr {
|
||||
if embeddedValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
embeddedValue = embeddedValue.Elem()
|
||||
}
|
||||
if fieldWasMapped(embeddedValue, fieldType, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||
if len(parentIDs) == 0 {
|
||||
|
||||
@@ -1370,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetValidJSONFieldNames returns a map of valid JSON field names for a model
|
||||
// This can be used to validate input data against a model's structure
|
||||
// The map keys are the JSON field names (from json tags) that exist in the model
|
||||
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
||||
validFields := make(map[string]bool)
|
||||
|
||||
// Unwrap pointers to get to the base struct type
|
||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return validFields
|
||||
}
|
||||
|
||||
collectValidFieldNames(modelType, validFields)
|
||||
return validFields
|
||||
}
|
||||
|
||||
// collectValidFieldNames recursively collects valid JSON field names from a struct type
|
||||
func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for embedded structs
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
// Recursively add fields from embedded struct
|
||||
collectValidFieldNames(fieldType, validFields)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get the JSON tag name for this field (same logic as MapToStruct)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract the field name from the JSON tag (before any options like omitempty)
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
validFields[parts[0]] = true
|
||||
}
|
||||
} else {
|
||||
// If no JSON tag, use the field name in lowercase as a fallback
|
||||
validFields[strings.ToLower(field.Name)] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
|
||||
@@ -883,7 +883,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
})
|
||||
|
||||
// Handle recursive preloading
|
||||
if preload.Recursive && depth < 5 {
|
||||
if preload.Recursive && depth < 4 {
|
||||
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
||||
|
||||
// For recursive relationships, we need to get the last part of the relation path
|
||||
|
||||
Reference in New Issue
Block a user