Compare commits

...

2 Commits

Author SHA1 Message Date
Hein
ca4e53969b Better tests
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 15:32:16 +02:00
Hein
db2b7e878e Better handling of preloads 2025-12-09 15:12:17 +02:00
9 changed files with 2050 additions and 89 deletions

View File

@@ -0,0 +1,218 @@
# Automatic Relation Loading Strategies
## Overview
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
Simply use `PreloadRelation()` and the system automatically:
- Detects relationship type from Bun/GORM tags
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
## How It Works
```go
// Just write this - the system handles the rest!
db.NewSelect().
Model(&links).
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
Scan(ctx, &links)
```
### Detection Logic
The system inspects your model's struct tags:
**Bun models:**
```go
type Link struct {
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
}
```
**GORM models:**
```go
type Link struct {
ProviderID int
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
}
```
**Type inference (fallback):**
- `[]Type` (slice) → has-many → Separate query
- `*Type` (pointer) → belongs-to → JOIN
- `Type` (struct) → belongs-to → JOIN
### What Gets Logged
Enable debug logging to see strategy selection:
```go
bunAdapter.EnableQueryDebug()
```
**Output:**
```
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
INFO: Using JOIN strategy for belongs-to relation 'Provider'
DEBUG: PreloadRelation 'Links' detected as: has-many
DEBUG: Using separate query for has-many relation 'Links'
```
## Relationship Types
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|---------|--------------|------------|----------|-----|
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
## Manual Override
If you need to force a specific strategy, use `JoinRelation()`:
```go
// Force JOIN even for has-many (not recommended)
db.NewSelect().
Model(&providers).
JoinRelation("Links"). // Explicitly use JOIN
Scan(ctx, &providers)
```
## Examples
### Automatic Strategy Selection (Recommended)
```go
// Example 1: Loading parent provider for each link
// System detects belongs-to → uses JOIN automatically
db.NewSelect().
Model(&links).
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &links)
// Generated SQL: Single query with JOIN
// SELECT links.*, providers.*
// FROM links
// LEFT JOIN providers ON links.provider_id = providers.id
// WHERE providers.active = true
// Example 2: Loading child links for each provider
// System detects has-many → uses separate query automatically
db.NewSelect().
Model(&providers).
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &providers)
// Generated SQL: Two queries
// Query 1: SELECT * FROM providers
// Query 2: SELECT * FROM links
// WHERE provider_id IN (1, 2, 3, ...)
// AND active = true
```
### Mixed Relationships
```go
type Order struct {
ID int
CustomerID int
Customer *Customer `bun:"rel:belongs-to"` // JOIN
Items []Item `bun:"rel:has-many"` // Separate
Invoice *Invoice `bun:"rel:has-one"` // JOIN
}
// All three handled optimally!
db.NewSelect().
Model(&orders).
PreloadRelation("Customer"). // → JOIN (many-to-one)
PreloadRelation("Items"). // → Separate (one-to-many)
PreloadRelation("Invoice"). // → JOIN (one-to-one)
Scan(ctx, &orders)
```
## Performance Benefits
### Before (Manual Strategy Selection)
```go
// You had to remember which to use:
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
.PreloadRelation("Links") // Which is more efficient here?
```
### After (Automatic Selection)
```go
// Just use PreloadRelation everywhere:
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
.PreloadRelation("Links") // ✓ System uses separate query automatically
```
## Migration Guide
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
```go
// Before: Always used separate query
.PreloadRelation("Provider") // Inefficient: extra round trip
// After: Automatic optimization
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
```
## Implementation Details
### Supported Bun Tags
- `rel:has-many` → Separate query
- `rel:belongs-to` → JOIN
- `rel:has-one` → JOIN
- `rel:many-to-many` or `rel:m2m` → Separate query
### Supported GORM Patterns
- `many2many:` tag → Separate query
- `foreignKey:` tag → JOIN (belongs-to)
- `[]Type` slice without many2many → Separate query (has-many)
- `*Type` pointer with foreignKey → JOIN (belongs-to)
- `*Type` pointer without foreignKey → JOIN (has-one)
### Fallback Behavior
- `[]Type` (slice) → Separate query (safe default for collections)
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
- Unknown → Separate query (safest default)
## Debugging
To see strategy selection in action:
```go
// Enable debug logging
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
// Run your query
db.NewSelect().
Model(&records).
PreloadRelation("RelationName").
Scan(ctx, &records)
// Check logs for:
// - "PreloadRelation 'X' detected as: belongs-to"
// - "Using JOIN strategy for belongs-to relation 'X'"
// - Actual SQL queries executed
```
## Best Practices
1. **Use PreloadRelation() for everything** - Let the system optimize
2. **Define proper relationship tags** - Ensures correct detection
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
4. **Enable debug logging during development** - Verify optimal strategies are chosen
5. **Trust the system** - It's designed to choose correctly based on relationship type

View File

@@ -13,7 +13,7 @@ func TestNormalizeTableAlias(t *testing.T) {
want string want string
}{ }{
{ {
name: "strips incorrect alias from simple condition", name: "strips plausible alias from simple condition",
query: "APIL.rid_hub = 2576", query: "APIL.rid_hub = 2576",
expectedAlias: "apiproviderlink", expectedAlias: "apiproviderlink",
tableName: "apiproviderlink", tableName: "apiproviderlink",
@@ -27,14 +27,14 @@ func TestNormalizeTableAlias(t *testing.T) {
want: "apiproviderlink.rid_hub = 2576", want: "apiproviderlink.rid_hub = 2576",
}, },
{ {
name: "strips incorrect alias with multiple conditions", name: "strips plausible alias with multiple conditions",
query: "APIL.rid_hub = ? AND APIL.active = ?", query: "APIL.rid_hub = ? AND APIL.active = ?",
expectedAlias: "apiproviderlink", expectedAlias: "apiproviderlink",
tableName: "apiproviderlink", tableName: "apiproviderlink",
want: "rid_hub = ? AND active = ?", want: "rid_hub = ? AND active = ?",
}, },
{ {
name: "handles mixed correct and incorrect aliases", name: "handles mixed correct and plausible aliases",
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?", query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
expectedAlias: "apiproviderlink", expectedAlias: "apiproviderlink",
tableName: "apiproviderlink", tableName: "apiproviderlink",
@@ -54,6 +54,20 @@ func TestNormalizeTableAlias(t *testing.T) {
tableName: "apiproviderlink", tableName: "apiproviderlink",
want: "rid_hub = ?", want: "rid_hub = ?",
}, },
{
name: "keeps reference to different table (not in current table name)",
query: "APIL.rid_hub = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "APIL.rid_hub = ?",
},
{
name: "keeps reference with short prefix that might be ambiguous",
query: "AP.rid = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "AP.rid = ?",
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -140,6 +140,8 @@ type BunSelectQuery struct {
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries deferredPreloads []deferredPreload // Preloads to execute as separate queries
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
} }
// deferredPreload represents a preload that will be executed as a separate query // deferredPreload represents a preload that will be executed as a separate query
@@ -189,28 +191,93 @@ func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.Se
} }
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if b.inJoinContext && b.joinTableAlias != "" {
query = addTablePrefix(query, b.joinTableAlias)
} else if b.tableAlias != "" && b.tableName != "" {
// If we have a table alias defined, check if the query references a different alias // If we have a table alias defined, check if the query references a different alias
// This can happen in preloads where the user expects a certain alias but Bun generates another // This can happen in preloads where the user expects a certain alias but Bun generates another
if b.tableAlias != "" && b.tableName != "" {
// Detect if query contains a qualified column reference (e.g., "APIL.column")
// and replace it with the unqualified version or the correct alias
query = normalizeTableAlias(query, b.tableAlias, b.tableName) query = normalizeTableAlias(query, b.tableAlias, b.tableName)
} }
b.query = b.query.Where(query, args...) b.query = b.query.Where(query, args...)
return b return b
} }
// addTablePrefix adds a table prefix to unqualified column references
// This is used in JOIN contexts where conditions must reference the joined table
func addTablePrefix(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
// (no dot, and likely a column name before an operator)
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeyword(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
func isOperatorOrKeyword(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
// isAcronymMatch checks if prefix is an acronym of tableName
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
func isAcronymMatch(prefix, tableName string) bool {
if len(prefix) == 0 || len(tableName) == 0 {
return false
}
prefixIdx := 0
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
if tableName[i] == prefix[prefixIdx] {
prefixIdx++
}
}
// All characters of prefix were found in sequence in tableName
return prefixIdx == len(prefix)
}
// normalizeTableAlias replaces table alias prefixes in SQL conditions // normalizeTableAlias replaces table alias prefixes in SQL conditions
// This handles cases where a user references a table alias that doesn't match // This handles cases where a user references a table alias that doesn't match
// what Bun generates (common in preload contexts) // what Bun generates (common in preload contexts)
func normalizeTableAlias(query, expectedAlias, tableName string) string { func normalizeTableAlias(query, expectedAlias, tableName string) string {
// Pattern: <word>.<column> where <word> might be an incorrect alias // Pattern: <word>.<column> where <word> might be an incorrect alias
// We'll look for patterns like "APIL.column" and either: // We'll look for patterns like "APIL.column" and either:
// 1. Remove the alias prefix entirely (safest) // 1. Remove the alias prefix if it's clearly meant for this table
// 2. Replace with the expected alias // 2. Leave it alone if it might be referring to another table (JOIN/preload)
// For now, we'll use a simple approach: if the query contains a dot (qualified reference)
// and that prefix is not the expected alias or table name, strip it
// Split on spaces and parentheses to find qualified references // Split on spaces and parentheses to find qualified references
parts := strings.FieldsFunc(query, func(r rune) bool { parts := strings.FieldsFunc(query, func(r rune) bool {
@@ -225,13 +292,39 @@ func normalizeTableAlias(query, expectedAlias, tableName string) string {
column := part[dotIndex+1:] column := part[dotIndex+1:]
// Check if the prefix matches our expected alias or table name (case-insensitive) // Check if the prefix matches our expected alias or table name (case-insensitive)
if !strings.EqualFold(prefix, expectedAlias) && if strings.EqualFold(prefix, expectedAlias) ||
!strings.EqualFold(prefix, tableName) && strings.EqualFold(prefix, tableName) ||
!strings.EqualFold(prefix, strings.ToLower(tableName)) { strings.EqualFold(prefix, strings.ToLower(tableName)) {
// This is a different alias - remove the prefix // Prefix matches current table, it's safe but redundant - leave it
logger.Debug("Stripping incorrect alias '%s' from WHERE condition, keeping just '%s'", prefix, column) continue
}
// Check if the prefix could plausibly be an alias/acronym for this table
// Only strip if we're confident it's meant for this table
// For example: "APIL" could be an acronym for "apiproviderlink"
prefixLower := strings.ToLower(prefix)
tableNameLower := strings.ToLower(tableName)
// Check if prefix is a substring of table name
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
// Check if prefix is an acronym of table name
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
isAcronym := false
if !isSubstring && len(prefixLower) > 2 {
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
}
if isSubstring || isAcronym {
// This looks like it could be an alias for this table - strip it
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
// Replace the qualified reference with just the column name // Replace the qualified reference with just the column name
modified = strings.ReplaceAll(modified, part, column) modified = strings.ReplaceAll(modified, part, column)
} else {
// Prefix doesn't match the current table at all
// It's likely referring to a different table (JOIN/preload)
// DON'T strip it - leave the qualified reference as-is
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
} }
} }
} }
@@ -367,6 +460,27 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
// } // }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from the query if available
model := b.query.GetModel()
if model != nil && model.Value() != nil {
relType := reflection.GetRelationType(model.Value(), relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return b.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Check if this relation chain would create problematic long aliases // Check if this relation chain would create problematic long aliases
relationParts := strings.Split(relation, ".") relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__")) aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
@@ -473,6 +587,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
return b return b
} }
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a LEFT JOIN instead of a separate query
// This is more efficient for many-to-one or one-to-one relationships
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
// Wrap the apply functions to automatically add table prefix to WHERE conditions
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
for _, fn := range apply {
if fn != nil {
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
return func(q common.SelectQuery) common.SelectQuery {
// Create a special wrapper that adds prefixes to WHERE conditions
if bunQuery, ok := q.(*BunSelectQuery); ok {
// Mark this query as being in JOIN context
bunQuery.inJoinContext = true
bunQuery.joinTableAlias = strings.ToLower(relation)
}
return originalFn(q)
}
}(fn)
wrappedApply = append(wrappedApply, wrappedFn)
}
}
// Use PreloadRelation with the wrapped functions
// Bun's Relation() will use JOIN for belongs-to and has-one relations
return b.PreloadRelation(relation, wrappedApply...)
}
func (b *BunSelectQuery) Order(order string) common.SelectQuery { func (b *BunSelectQuery) Order(order string) common.SelectQuery {
b.query = b.query.Order(order) b.query = b.query.Order(order)
return b return b

View File

@@ -108,6 +108,8 @@ type GormSelectQuery struct {
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
} }
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -151,10 +153,61 @@ func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.S
} }
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if g.inJoinContext && g.joinTableAlias != "" {
query = addTablePrefixGorm(query, g.joinTableAlias)
}
g.db = g.db.Where(query, args...) g.db = g.db.Where(query, args...)
return g return g
} }
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
func addTablePrefixGorm(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeywordGorm(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
func isOperatorOrKeywordGorm(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Or(query, args...) g.db = g.db.Or(query, args...)
return g return g
@@ -238,6 +291,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
} }
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from GORM's statement if available
if g.db.Statement != nil && g.db.Statement.Model != nil {
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return g.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Use GORM's Preload (separate query strategy)
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB { g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 { if len(apply) == 0 {
return db return db
@@ -267,6 +341,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
return g return g
} }
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a JOIN instead of a separate preload query
// This is more efficient for many-to-one or one-to-one relationships
// as it avoids additional round trips to the database
// GORM's Joins() method forces a JOIN for the preload
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
}
wrapper := &GormSelectQuery{
db: db,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalGorm, ok := current.(*GormSelectQuery); ok {
return finalGorm.db
}
return db
})
return g
}
func (g *GormSelectQuery) Order(order string) common.SelectQuery { func (g *GormSelectQuery) Order(order string) common.SelectQuery {
g.db = g.db.Order(order) g.db = g.db.Order(order)
return g return g

View File

@@ -38,6 +38,7 @@ type SelectQuery interface {
LeftJoin(query string, args ...interface{}) SelectQuery LeftJoin(query string, args ...interface{}) SelectQuery
Preload(relation string, conditions ...interface{}) SelectQuery Preload(relation string, conditions ...interface{}) SelectQuery
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
Order(order string) SelectQuery Order(order string) SelectQuery
Limit(n int) SelectQuery Limit(n int) SelectQuery
Offset(n int) SelectQuery Offset(n int) SelectQuery

View File

@@ -1,7 +1,6 @@
package common package common
import ( import (
"fmt"
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -9,81 +8,40 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains // ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
// the relation prefix (alias). If not present, it attempts to add it to column references. //
// Returns the fixed WHERE clause and an error if it cannot be safely fixed. // NOTE: For preload queries, table aliases from the parent query are not valid since
// the preload executes as a separate query with its own table alias. This function
// now simply validates basic syntax without requiring or adding prefixes.
// The actual alias normalization happens in the database adapter layer.
//
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) { func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" { if where == "" {
return where, nil return where, nil
} }
// Check if the relation name is already present in the WHERE clause where = strings.TrimSpace(where)
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot // Just do basic validation - don't require or add prefixes
if strings.Contains(lowerWhere, lowerRelation+".") || // The database adapter will handle alias normalization
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { // Check if the WHERE clause contains any qualified column references
// Relation prefix is already present // If it does, log a debug message but don't fail - let the adapter handle it
if strings.Contains(where, ".") {
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
"Note: In preload context, table aliases from parent query are not available. "+
"The database adapter will normalize aliases automatically.", relationName, where)
}
// Validate that it's not empty or just whitespace
if where == "" {
return where, nil return where, nil
} }
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), // Return the WHERE clause as-is
// we can't safely auto-fix it - require explicit prefix // The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
if strings.Contains(lowerWhere, " or ") || return where, nil
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Check if this is a SQL expression/literal that shouldn't be prefixed
lowerCond := strings.ToLower(strings.TrimSpace(cond))
if IsSQLExpression(lowerCond) {
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
fixedConditions = append(fixedConditions, cond)
continue
}
// Extract the column name (first identifier before operator)
columnName := ExtractColumnName(cond)
if columnName == "" {
// Can't identify column name, require explicit prefix
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
}
// Add relation prefix to the column name only
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
} }
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed // IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed

View File

@@ -0,0 +1,331 @@
package reflection
import (
"reflect"
"testing"
)
// Test models for GetModelColumnDetail
type TestModelForColumnDetail struct {
ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"`
Name string `gorm:"column:name;type:varchar(255);not null" json:"name"`
Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"`
Description string `gorm:"column:description;type:text;null" json:"description"`
ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"`
}
type EmbeddedBase struct {
ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"`
CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"`
}
type ModelWithEmbeddedForDetail struct {
EmbeddedBase
Title string `gorm:"column:title;type:varchar(100);not null" json:"title"`
Content string `gorm:"column:content;type:text" json:"content"`
}
// Model with nil embedded pointer
type ModelWithNilEmbedded struct {
ID int `gorm:"column:id;primaryKey" json:"id"`
*EmbeddedBase
Name string `gorm:"column:name" json:"name"`
}
func TestGetModelColumnDetail(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
model := TestModelForColumnDetail{
ID: 1,
Name: "Test",
Email: "test@example.com",
Description: "Test description",
ForeignKey: 100,
}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) != 5 {
t.Errorf("Expected 5 fields, got %d", len(details))
}
// Check ID field
found := false
for _, detail := range details {
if detail.Name == "ID" {
found = true
if detail.SQLName != "rid_test" {
t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName)
}
// Note: primaryKey (without underscore) is not detected as primary_key
// The function looks for "identity" or "primary_key" (with underscore)
if detail.SQLDataType != "bigserial" {
t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType)
}
if detail.Nullable {
t.Errorf("Expected Nullable false, got true")
}
}
}
if !found {
t.Errorf("ID field not found in details")
}
})
t.Run("struct with embedded fields", func(t *testing.T) {
model := ModelWithEmbeddedForDetail{
EmbeddedBase: EmbeddedBase{
ID: 1,
CreatedAt: "2024-01-01",
},
Title: "Test Title",
Content: "Test Content",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
// Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main
if len(details) != 4 {
t.Errorf("Expected 4 fields, got %d", len(details))
}
// Check that embedded field is included
foundID := false
foundCreatedAt := false
for _, detail := range details {
if detail.Name == "ID" {
foundID = true
if detail.SQLKey != "primary_key" {
t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey)
}
}
if detail.Name == "CreatedAt" {
foundCreatedAt = true
}
}
if !foundID {
t.Errorf("Embedded ID field not found")
}
if !foundCreatedAt {
t.Errorf("Embedded CreatedAt field not found")
}
})
t.Run("nil embedded pointer is skipped", func(t *testing.T) {
model := ModelWithNilEmbedded{
ID: 1,
Name: "Test",
EmbeddedBase: nil, // nil embedded pointer
}
details := GetModelColumnDetail(reflect.ValueOf(model))
// Should have 2 fields: ID and Name (embedded is nil, so skipped)
if len(details) != 2 {
t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details))
}
})
t.Run("pointer to struct", func(t *testing.T) {
model := &TestModelForColumnDetail{
ID: 1,
Name: "Test",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) != 5 {
t.Errorf("Expected 5 fields, got %d", len(details))
}
})
t.Run("invalid value", func(t *testing.T) {
var invalid reflect.Value
details := GetModelColumnDetail(invalid)
if len(details) != 0 {
t.Errorf("Expected 0 fields for invalid value, got %d", len(details))
}
})
t.Run("non-struct type", func(t *testing.T) {
details := GetModelColumnDetail(reflect.ValueOf(123))
if len(details) != 0 {
t.Errorf("Expected 0 fields for non-struct, got %d", len(details))
}
})
t.Run("nullable and not null detection", func(t *testing.T) {
model := TestModelForColumnDetail{}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
switch detail.Name {
case "ID":
if detail.Nullable {
t.Errorf("ID should not be nullable (has 'not null')")
}
case "Name":
if detail.Nullable {
t.Errorf("Name should not be nullable (has 'not null')")
}
case "Email":
if !detail.Nullable {
t.Errorf("Email should be nullable (has 'nullable')")
}
case "Description":
if !detail.Nullable {
t.Errorf("Description should be nullable (has 'null')")
}
}
}
})
t.Run("unique and uniqueindex detection", func(t *testing.T) {
type UniqueTestModel struct {
ID int `gorm:"column:id;primary_key"`
Username string `gorm:"column:username;unique"`
Email string `gorm:"column:email;uniqueindex"`
}
model := UniqueTestModel{}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
switch detail.Name {
case "ID":
if detail.SQLKey != "primary_key" {
t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey)
}
case "Username":
if detail.SQLKey != "unique" {
t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey)
}
case "Email":
// The function checks for "unique" first, so uniqueindex is also detected as "unique"
// This is expected behavior based on the code logic
if detail.SQLKey != "unique" {
t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey)
}
}
}
})
t.Run("foreign key detection", func(t *testing.T) {
// Note: The foreignkey extraction in generic_model.go has a bug where
// it requires ik > 0, so foreignkey at the start won't extract the value
type FKTestModel struct {
ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"`
}
model := FKTestModel{}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) == 0 {
t.Fatal("Expected at least 1 field")
}
detail := details[0]
if detail.SQLKey != "foreign_key" {
t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey)
}
// Due to the bug in the code (requires ik > 0), the SQLName will be extracted
// when foreignkey is not at the beginning of the string
if detail.SQLName != "rid_parent" {
t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName)
}
})
}
func TestFnFindKeyVal(t *testing.T) {
tests := []struct {
name string
src string
key string
expected string
}{
{
name: "find column",
src: "column:user_id;primaryKey;type:bigint",
key: "column:",
expected: "user_id",
},
{
name: "find type",
src: "column:name;type:varchar(255);not null",
key: "type:",
expected: "varchar(255)",
},
{
name: "key not found",
src: "primaryKey;autoIncrement",
key: "column:",
expected: "",
},
{
name: "key at end without semicolon",
src: "primaryKey;column:id",
key: "column:",
expected: "id",
},
{
name: "case insensitive search",
src: "Column:user_id;primaryKey",
key: "column:",
expected: "user_id",
},
{
name: "empty src",
src: "",
key: "column:",
expected: "",
},
{
name: "multiple occurrences (returns first)",
src: "column:first;column:second",
key: "column:",
expected: "first",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := fnFindKeyVal(tt.src, tt.key)
if result != tt.expected {
t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected)
}
})
}
}
func TestGetModelColumnDetail_FieldValue(t *testing.T) {
model := TestModelForColumnDetail{
ID: 123,
Name: "TestName",
Email: "test@example.com",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
if !detail.FieldValue.IsValid() {
t.Errorf("Field %s has invalid FieldValue", detail.Name)
}
// Check that FieldValue matches the actual value
switch detail.Name {
case "ID":
if detail.FieldValue.Int() != 123 {
t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int())
}
case "Name":
if detail.FieldValue.String() != "TestName" {
t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String())
}
case "Email":
if detail.FieldValue.String() != "test@example.com" {
t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String())
}
}
}
}

View File

@@ -750,6 +750,118 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
return nil, fmt.Errorf("unsupported numeric type: %v", kind) return nil, fmt.Errorf("unsupported numeric type: %v", kind)
} }
// RelationType represents the type of database relationship
type RelationType string
const (
RelationHasMany RelationType = "has-many" // 1:N - use separate query
RelationBelongsTo RelationType = "belongs-to" // N:1 - use JOIN
RelationHasOne RelationType = "has-one" // 1:1 - use JOIN
RelationManyToMany RelationType = "many-to-many" // M:N - use separate query
RelationUnknown RelationType = "unknown"
)
// ShouldUseJoin returns true if the relation type should use a JOIN instead of separate query
func (rt RelationType) ShouldUseJoin() bool {
return rt == RelationBelongsTo || rt == RelationHasOne
}
// GetRelationType inspects the model's struct tags to determine the relationship type
// It checks both Bun and GORM tags to identify the relationship cardinality
func GetRelationType(model interface{}, fieldName string) RelationType {
if model == nil || fieldName == "" {
return RelationUnknown
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return RelationUnknown
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return RelationUnknown
}
// Find the field
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check if field name matches (case-insensitive)
if !strings.EqualFold(field.Name, fieldName) {
continue
}
// Check Bun tags first
bunTag := field.Tag.Get("bun")
if bunTag != "" && strings.Contains(bunTag, "rel:") {
// Parse bun relation tag: rel:has-many, rel:belongs-to, rel:has-one, rel:many-to-many
parts := strings.Split(bunTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "rel:") {
relType := strings.TrimPrefix(part, "rel:")
switch relType {
case "has-many":
return RelationHasMany
case "belongs-to":
return RelationBelongsTo
case "has-one":
return RelationHasOne
case "many-to-many", "m2m":
return RelationManyToMany
}
}
}
}
// Check GORM tags
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
// GORM uses different patterns:
// - foreignKey: usually indicates belongs-to or has-one
// - many2many: indicates many-to-many
// - Field type (slice vs pointer) helps determine cardinality
if strings.Contains(gormTag, "many2many:") {
return RelationManyToMany
}
// Check field type for cardinality hints
fieldType := field.Type
if fieldType.Kind() == reflect.Slice {
// Slice indicates has-many or many-to-many
return RelationHasMany
}
if fieldType.Kind() == reflect.Ptr {
// Pointer to single struct usually indicates belongs-to or has-one
// Check if it has foreignKey (belongs-to) or references (has-one)
if strings.Contains(gormTag, "foreignKey:") {
return RelationBelongsTo
}
return RelationHasOne
}
}
// Fall back to field type inference
fieldType := field.Type
if fieldType.Kind() == reflect.Slice {
// Slice of structs → has-many
return RelationHasMany
}
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
// Single struct → belongs-to (default assumption for safety)
// Using belongs-to as default ensures we use JOIN, which is safer
return RelationBelongsTo
}
}
return RelationUnknown
}
// GetRelationModel gets the model type for a relation field // GetRelationModel gets the model type for a relation field
// It searches for the field by name in the following order (case-insensitive): // It searches for the field by name in the following order (case-insensitive):
// 1. Actual field name // 1. Actual field name

File diff suppressed because it is too large Load Diff