diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 5241ead..4479976 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -204,7 +204,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common. } wrapper := &GormSelectQuery{ - db: g.db, + db: db, } current := common.SelectQuery(wrapper) diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 23dc977..7466116 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1137,8 +1137,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // This requires a more sophisticated approach with callbacks or query builders // Apply preloading - logger.Debug("Applying preload: %s", preload.Relation) - query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery { + logger.Debug("Applying preload: %s", relationFieldName) + query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery { if len(preload.OmitColumns) > 0 { allCols := reflection.GetModelColumns(model) // Remove omitted columns @@ -1158,7 +1158,28 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre } if len(preload.Columns) > 0 { - sq = sq.Column(preload.Columns...) + // Ensure foreign key is included in column selection for GORM to establish the relationship + columns := make([]string, len(preload.Columns)) + copy(columns, preload.Columns) + + // Add foreign key if not already present + if relInfo.foreignKey != "" { + // Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id) + foreignKeyColumn := toSnakeCase(relInfo.foreignKey) + + hasForeignKey := false + for _, col := range columns { + if col == foreignKeyColumn || col == relInfo.foreignKey { + hasForeignKey = true + break + } + } + if !hasForeignKey { + columns = append(columns, foreignKeyColumn) + } + } + + sq = sq.Column(columns...) } if len(preload.Filters) > 0 { @@ -1240,3 +1261,28 @@ func (h *Handler) extractTagValue(tag, key string) string { } return "" } + +// toSnakeCase converts a PascalCase or camelCase string to snake_case +func toSnakeCase(s string) string { + var result strings.Builder + runes := []rune(s) + + for i := 0; i < len(runes); i++ { + r := runes[i] + + if i > 0 && r >= 'A' && r <= 'Z' { + // Check if previous character is lowercase or if next character is lowercase + prevIsLower := runes[i-1] >= 'a' && runes[i-1] <= 'z' + nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z' + + // Add underscore if this is the start of a new word + // (previous was lowercase OR this is followed by lowercase) + if prevIsLower || nextIsLower { + result.WriteByte('_') + } + } + + result.WriteRune(r) + } + return strings.ToLower(result.String()) +}