mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 16:24:26 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e1abd5ebc1 | ||
|
|
ca4e53969b | ||
|
|
db2b7e878e | ||
|
|
9572bfc7b8 |
218
pkg/common/adapters/database/RELATION_LOADING.md
Normal file
218
pkg/common/adapters/database/RELATION_LOADING.md
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
# Automatic Relation Loading Strategies
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
|
||||||
|
|
||||||
|
Simply use `PreloadRelation()` and the system automatically:
|
||||||
|
- Detects relationship type from Bun/GORM tags
|
||||||
|
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
|
||||||
|
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Just write this - the system handles the rest!
|
||||||
|
db.NewSelect().
|
||||||
|
Model(&links).
|
||||||
|
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
|
||||||
|
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
|
||||||
|
Scan(ctx, &links)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Detection Logic
|
||||||
|
|
||||||
|
The system inspects your model's struct tags:
|
||||||
|
|
||||||
|
**Bun models:**
|
||||||
|
```go
|
||||||
|
type Link struct {
|
||||||
|
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
|
||||||
|
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**GORM models:**
|
||||||
|
```go
|
||||||
|
type Link struct {
|
||||||
|
ProviderID int
|
||||||
|
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
|
||||||
|
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Type inference (fallback):**
|
||||||
|
- `[]Type` (slice) → has-many → Separate query
|
||||||
|
- `*Type` (pointer) → belongs-to → JOIN
|
||||||
|
- `Type` (struct) → belongs-to → JOIN
|
||||||
|
|
||||||
|
### What Gets Logged
|
||||||
|
|
||||||
|
Enable debug logging to see strategy selection:
|
||||||
|
|
||||||
|
```go
|
||||||
|
bunAdapter.EnableQueryDebug()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
```
|
||||||
|
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
|
||||||
|
INFO: Using JOIN strategy for belongs-to relation 'Provider'
|
||||||
|
DEBUG: PreloadRelation 'Links' detected as: has-many
|
||||||
|
DEBUG: Using separate query for has-many relation 'Links'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Relationship Types
|
||||||
|
|
||||||
|
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|
||||||
|
|---------|--------------|------------|----------|-----|
|
||||||
|
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
|
||||||
|
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
|
||||||
|
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
|
||||||
|
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
|
||||||
|
|
||||||
|
## Manual Override
|
||||||
|
|
||||||
|
If you need to force a specific strategy, use `JoinRelation()`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Force JOIN even for has-many (not recommended)
|
||||||
|
db.NewSelect().
|
||||||
|
Model(&providers).
|
||||||
|
JoinRelation("Links"). // Explicitly use JOIN
|
||||||
|
Scan(ctx, &providers)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
### Automatic Strategy Selection (Recommended)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Example 1: Loading parent provider for each link
|
||||||
|
// System detects belongs-to → uses JOIN automatically
|
||||||
|
db.NewSelect().
|
||||||
|
Model(&links).
|
||||||
|
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
|
||||||
|
return q.Where("active = ?", true)
|
||||||
|
}).
|
||||||
|
Scan(ctx, &links)
|
||||||
|
|
||||||
|
// Generated SQL: Single query with JOIN
|
||||||
|
// SELECT links.*, providers.*
|
||||||
|
// FROM links
|
||||||
|
// LEFT JOIN providers ON links.provider_id = providers.id
|
||||||
|
// WHERE providers.active = true
|
||||||
|
|
||||||
|
// Example 2: Loading child links for each provider
|
||||||
|
// System detects has-many → uses separate query automatically
|
||||||
|
db.NewSelect().
|
||||||
|
Model(&providers).
|
||||||
|
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
|
||||||
|
return q.Where("active = ?", true)
|
||||||
|
}).
|
||||||
|
Scan(ctx, &providers)
|
||||||
|
|
||||||
|
// Generated SQL: Two queries
|
||||||
|
// Query 1: SELECT * FROM providers
|
||||||
|
// Query 2: SELECT * FROM links
|
||||||
|
// WHERE provider_id IN (1, 2, 3, ...)
|
||||||
|
// AND active = true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mixed Relationships
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Order struct {
|
||||||
|
ID int
|
||||||
|
CustomerID int
|
||||||
|
Customer *Customer `bun:"rel:belongs-to"` // JOIN
|
||||||
|
Items []Item `bun:"rel:has-many"` // Separate
|
||||||
|
Invoice *Invoice `bun:"rel:has-one"` // JOIN
|
||||||
|
}
|
||||||
|
|
||||||
|
// All three handled optimally!
|
||||||
|
db.NewSelect().
|
||||||
|
Model(&orders).
|
||||||
|
PreloadRelation("Customer"). // → JOIN (many-to-one)
|
||||||
|
PreloadRelation("Items"). // → Separate (one-to-many)
|
||||||
|
PreloadRelation("Invoice"). // → JOIN (one-to-one)
|
||||||
|
Scan(ctx, &orders)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Benefits
|
||||||
|
|
||||||
|
### Before (Manual Strategy Selection)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// You had to remember which to use:
|
||||||
|
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
|
||||||
|
.PreloadRelation("Links") // Which is more efficient here?
|
||||||
|
```
|
||||||
|
|
||||||
|
### After (Automatic Selection)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Just use PreloadRelation everywhere:
|
||||||
|
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
|
||||||
|
.PreloadRelation("Links") // ✓ System uses separate query automatically
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration Guide
|
||||||
|
|
||||||
|
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Before: Always used separate query
|
||||||
|
.PreloadRelation("Provider") // Inefficient: extra round trip
|
||||||
|
|
||||||
|
// After: Automatic optimization
|
||||||
|
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### Supported Bun Tags
|
||||||
|
- `rel:has-many` → Separate query
|
||||||
|
- `rel:belongs-to` → JOIN
|
||||||
|
- `rel:has-one` → JOIN
|
||||||
|
- `rel:many-to-many` or `rel:m2m` → Separate query
|
||||||
|
|
||||||
|
### Supported GORM Patterns
|
||||||
|
- `many2many:` tag → Separate query
|
||||||
|
- `foreignKey:` tag → JOIN (belongs-to)
|
||||||
|
- `[]Type` slice without many2many → Separate query (has-many)
|
||||||
|
- `*Type` pointer with foreignKey → JOIN (belongs-to)
|
||||||
|
- `*Type` pointer without foreignKey → JOIN (has-one)
|
||||||
|
|
||||||
|
### Fallback Behavior
|
||||||
|
- `[]Type` (slice) → Separate query (safe default for collections)
|
||||||
|
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
|
||||||
|
- Unknown → Separate query (safest default)
|
||||||
|
|
||||||
|
## Debugging
|
||||||
|
|
||||||
|
To see strategy selection in action:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Enable debug logging
|
||||||
|
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
|
||||||
|
|
||||||
|
// Run your query
|
||||||
|
db.NewSelect().
|
||||||
|
Model(&records).
|
||||||
|
PreloadRelation("RelationName").
|
||||||
|
Scan(ctx, &records)
|
||||||
|
|
||||||
|
// Check logs for:
|
||||||
|
// - "PreloadRelation 'X' detected as: belongs-to"
|
||||||
|
// - "Using JOIN strategy for belongs-to relation 'X'"
|
||||||
|
// - Actual SQL queries executed
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use PreloadRelation() for everything** - Let the system optimize
|
||||||
|
2. **Define proper relationship tags** - Ensures correct detection
|
||||||
|
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
|
||||||
|
4. **Enable debug logging during development** - Verify optimal strategies are chosen
|
||||||
|
5. **Trust the system** - It's designed to choose correctly based on relationship type
|
||||||
81
pkg/common/adapters/database/alias_test.go
Normal file
81
pkg/common/adapters/database/alias_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeTableAlias(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expectedAlias string
|
||||||
|
tableName string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "strips plausible alias from simple condition",
|
||||||
|
query: "APIL.rid_hub = 2576",
|
||||||
|
expectedAlias: "apiproviderlink",
|
||||||
|
tableName: "apiproviderlink",
|
||||||
|
want: "rid_hub = 2576",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keeps correct alias",
|
||||||
|
query: "apiproviderlink.rid_hub = 2576",
|
||||||
|
expectedAlias: "apiproviderlink",
|
||||||
|
tableName: "apiproviderlink",
|
||||||
|
want: "apiproviderlink.rid_hub = 2576",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "strips plausible alias with multiple conditions",
|
||||||
|
query: "APIL.rid_hub = ? AND APIL.active = ?",
|
||||||
|
expectedAlias: "apiproviderlink",
|
||||||
|
tableName: "apiproviderlink",
|
||||||
|
want: "rid_hub = ? AND active = ?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles mixed correct and plausible aliases",
|
||||||
|
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
|
||||||
|
expectedAlias: "apiproviderlink",
|
||||||
|
tableName: "apiproviderlink",
|
||||||
|
want: "rid_hub = ? AND apiproviderlink.active = ?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles parentheses",
|
||||||
|
query: "(APIL.rid_hub = ?)",
|
||||||
|
expectedAlias: "apiproviderlink",
|
||||||
|
tableName: "apiproviderlink",
|
||||||
|
want: "(rid_hub = ?)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no alias in query",
|
||||||
|
query: "rid_hub = ?",
|
||||||
|
expectedAlias: "apiproviderlink",
|
||||||
|
tableName: "apiproviderlink",
|
||||||
|
want: "rid_hub = ?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keeps reference to different table (not in current table name)",
|
||||||
|
query: "APIL.rid_hub = ?",
|
||||||
|
expectedAlias: "apiprovider",
|
||||||
|
tableName: "apiprovider",
|
||||||
|
want: "APIL.rid_hub = ?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keeps reference with short prefix that might be ambiguous",
|
||||||
|
query: "AP.rid = ?",
|
||||||
|
expectedAlias: "apiprovider",
|
||||||
|
tableName: "apiprovider",
|
||||||
|
want: "AP.rid = ?",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -140,6 +140,8 @@ type BunSelectQuery struct {
|
|||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||||
|
inJoinContext bool // Track if we're in a JOIN relation context
|
||||||
|
joinTableAlias string // Alias to use for JOIN conditions
|
||||||
}
|
}
|
||||||
|
|
||||||
// deferredPreload represents a preload that will be executed as a separate query
|
// deferredPreload represents a preload that will be executed as a separate query
|
||||||
@@ -189,10 +191,147 @@ func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.Se
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
// If we're in a JOIN context, add table prefix to unqualified columns
|
||||||
|
if b.inJoinContext && b.joinTableAlias != "" {
|
||||||
|
query = addTablePrefix(query, b.joinTableAlias)
|
||||||
|
} else if b.tableAlias != "" && b.tableName != "" {
|
||||||
|
// If we have a table alias defined, check if the query references a different alias
|
||||||
|
// This can happen in preloads where the user expects a certain alias but Bun generates another
|
||||||
|
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
|
||||||
|
}
|
||||||
b.query = b.query.Where(query, args...)
|
b.query = b.query.Where(query, args...)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addTablePrefix adds a table prefix to unqualified column references
|
||||||
|
// This is used in JOIN contexts where conditions must reference the joined table
|
||||||
|
func addTablePrefix(query, tableAlias string) string {
|
||||||
|
if tableAlias == "" || query == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split on spaces and parentheses to find column references
|
||||||
|
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||||
|
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||||
|
})
|
||||||
|
|
||||||
|
modified := query
|
||||||
|
for _, part := range parts {
|
||||||
|
// Check if this looks like an unqualified column reference
|
||||||
|
// (no dot, and likely a column name before an operator)
|
||||||
|
if !strings.Contains(part, ".") {
|
||||||
|
// Extract potential column name (before = or other operators)
|
||||||
|
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
|
||||||
|
if strings.Contains(part, op) {
|
||||||
|
colName := strings.Split(part, op)[0]
|
||||||
|
colName = strings.TrimSpace(colName)
|
||||||
|
if colName != "" && !isOperatorOrKeyword(colName) {
|
||||||
|
// Add table prefix
|
||||||
|
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
|
||||||
|
modified = strings.ReplaceAll(modified, part, prefixed)
|
||||||
|
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
|
||||||
|
func isOperatorOrKeyword(s string) bool {
|
||||||
|
s = strings.ToUpper(strings.TrimSpace(s))
|
||||||
|
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
|
||||||
|
for _, kw := range keywords {
|
||||||
|
if s == kw {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAcronymMatch checks if prefix is an acronym of tableName
|
||||||
|
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
|
||||||
|
func isAcronymMatch(prefix, tableName string) bool {
|
||||||
|
if len(prefix) == 0 || len(tableName) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixIdx := 0
|
||||||
|
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
|
||||||
|
if tableName[i] == prefix[prefixIdx] {
|
||||||
|
prefixIdx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All characters of prefix were found in sequence in tableName
|
||||||
|
return prefixIdx == len(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeTableAlias replaces table alias prefixes in SQL conditions
|
||||||
|
// This handles cases where a user references a table alias that doesn't match
|
||||||
|
// what Bun generates (common in preload contexts)
|
||||||
|
func normalizeTableAlias(query, expectedAlias, tableName string) string {
|
||||||
|
// Pattern: <word>.<column> where <word> might be an incorrect alias
|
||||||
|
// We'll look for patterns like "APIL.column" and either:
|
||||||
|
// 1. Remove the alias prefix if it's clearly meant for this table
|
||||||
|
// 2. Leave it alone if it might be referring to another table (JOIN/preload)
|
||||||
|
|
||||||
|
// Split on spaces and parentheses to find qualified references
|
||||||
|
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||||
|
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||||
|
})
|
||||||
|
|
||||||
|
modified := query
|
||||||
|
for _, part := range parts {
|
||||||
|
// Check if this looks like a qualified column reference
|
||||||
|
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
|
||||||
|
prefix := part[:dotIndex]
|
||||||
|
column := part[dotIndex+1:]
|
||||||
|
|
||||||
|
// Check if the prefix matches our expected alias or table name (case-insensitive)
|
||||||
|
if strings.EqualFold(prefix, expectedAlias) ||
|
||||||
|
strings.EqualFold(prefix, tableName) ||
|
||||||
|
strings.EqualFold(prefix, strings.ToLower(tableName)) {
|
||||||
|
// Prefix matches current table, it's safe but redundant - leave it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the prefix could plausibly be an alias/acronym for this table
|
||||||
|
// Only strip if we're confident it's meant for this table
|
||||||
|
// For example: "APIL" could be an acronym for "apiproviderlink"
|
||||||
|
prefixLower := strings.ToLower(prefix)
|
||||||
|
tableNameLower := strings.ToLower(tableName)
|
||||||
|
|
||||||
|
// Check if prefix is a substring of table name
|
||||||
|
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
|
||||||
|
|
||||||
|
// Check if prefix is an acronym of table name
|
||||||
|
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
|
||||||
|
isAcronym := false
|
||||||
|
if !isSubstring && len(prefixLower) > 2 {
|
||||||
|
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isSubstring || isAcronym {
|
||||||
|
// This looks like it could be an alias for this table - strip it
|
||||||
|
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
|
||||||
|
// Replace the qualified reference with just the column name
|
||||||
|
modified = strings.ReplaceAll(modified, part, column)
|
||||||
|
} else {
|
||||||
|
// Prefix doesn't match the current table at all
|
||||||
|
// It's likely referring to a different table (JOIN/preload)
|
||||||
|
// DON'T strip it - leave the qualified reference as-is
|
||||||
|
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||||
b.query = b.query.WhereOr(query, args...)
|
b.query = b.query.WhereOr(query, args...)
|
||||||
return b
|
return b
|
||||||
@@ -321,6 +460,27 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
// Auto-detect relationship type and choose optimal loading strategy
|
||||||
|
// Get the model from the query if available
|
||||||
|
model := b.query.GetModel()
|
||||||
|
if model != nil && model.Value() != nil {
|
||||||
|
relType := reflection.GetRelationType(model.Value(), relation)
|
||||||
|
|
||||||
|
// Log the detected relationship type
|
||||||
|
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||||
|
|
||||||
|
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||||
|
if relType.ShouldUseJoin() {
|
||||||
|
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||||
|
return b.JoinRelation(relation, apply...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||||
|
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||||
|
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this relation chain would create problematic long aliases
|
// Check if this relation chain would create problematic long aliases
|
||||||
relationParts := strings.Split(relation, ".")
|
relationParts := strings.Split(relation, ".")
|
||||||
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||||
@@ -383,6 +543,28 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
db: b.db,
|
db: b.db,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to extract table name and alias from the preload model
|
||||||
|
if model := sq.GetModel(); model != nil && model.Value() != nil {
|
||||||
|
modelValue := model.Value()
|
||||||
|
|
||||||
|
// Extract table name if model implements TableNameProvider
|
||||||
|
if provider, ok := modelValue.(common.TableNameProvider); ok {
|
||||||
|
fullTableName := provider.TableName()
|
||||||
|
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract table alias if model implements TableAliasProvider
|
||||||
|
if provider, ok := modelValue.(common.TableAliasProvider); ok {
|
||||||
|
wrapper.tableAlias = provider.TableAlias()
|
||||||
|
// Apply the alias to the Bun query so conditions can reference it
|
||||||
|
if wrapper.tableAlias != "" {
|
||||||
|
// Note: Bun's Relation() already sets up the table, but we can add
|
||||||
|
// the alias explicitly if needed
|
||||||
|
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start with the interface value (not pointer)
|
// Start with the interface value (not pointer)
|
||||||
current := common.SelectQuery(wrapper)
|
current := common.SelectQuery(wrapper)
|
||||||
|
|
||||||
@@ -405,6 +587,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
// JoinRelation uses a LEFT JOIN instead of a separate query
|
||||||
|
// This is more efficient for many-to-one or one-to-one relationships
|
||||||
|
|
||||||
|
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
|
||||||
|
|
||||||
|
// Wrap the apply functions to automatically add table prefix to WHERE conditions
|
||||||
|
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
|
||||||
|
for _, fn := range apply {
|
||||||
|
if fn != nil {
|
||||||
|
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
|
||||||
|
return func(q common.SelectQuery) common.SelectQuery {
|
||||||
|
// Create a special wrapper that adds prefixes to WHERE conditions
|
||||||
|
if bunQuery, ok := q.(*BunSelectQuery); ok {
|
||||||
|
// Mark this query as being in JOIN context
|
||||||
|
bunQuery.inJoinContext = true
|
||||||
|
bunQuery.joinTableAlias = strings.ToLower(relation)
|
||||||
|
}
|
||||||
|
return originalFn(q)
|
||||||
|
}
|
||||||
|
}(fn)
|
||||||
|
wrappedApply = append(wrappedApply, wrappedFn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use PreloadRelation with the wrapped functions
|
||||||
|
// Bun's Relation() will use JOIN for belongs-to and has-one relations
|
||||||
|
return b.PreloadRelation(relation, wrappedApply...)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||||
b.query = b.query.Order(order)
|
b.query = b.query.Order(order)
|
||||||
return b
|
return b
|
||||||
|
|||||||
@@ -104,10 +104,12 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
|
|
||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
schema string // Separated schema name
|
schema string // Separated schema name
|
||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
|
inJoinContext bool // Track if we're in a JOIN relation context
|
||||||
|
joinTableAlias string // Alias to use for JOIN conditions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
@@ -151,10 +153,61 @@ func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.S
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
// If we're in a JOIN context, add table prefix to unqualified columns
|
||||||
|
if g.inJoinContext && g.joinTableAlias != "" {
|
||||||
|
query = addTablePrefixGorm(query, g.joinTableAlias)
|
||||||
|
}
|
||||||
g.db = g.db.Where(query, args...)
|
g.db = g.db.Where(query, args...)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
|
||||||
|
func addTablePrefixGorm(query, tableAlias string) string {
|
||||||
|
if tableAlias == "" || query == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split on spaces and parentheses to find column references
|
||||||
|
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||||
|
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||||
|
})
|
||||||
|
|
||||||
|
modified := query
|
||||||
|
for _, part := range parts {
|
||||||
|
// Check if this looks like an unqualified column reference
|
||||||
|
if !strings.Contains(part, ".") {
|
||||||
|
// Extract potential column name (before = or other operators)
|
||||||
|
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
|
||||||
|
if strings.Contains(part, op) {
|
||||||
|
colName := strings.Split(part, op)[0]
|
||||||
|
colName = strings.TrimSpace(colName)
|
||||||
|
if colName != "" && !isOperatorOrKeywordGorm(colName) {
|
||||||
|
// Add table prefix
|
||||||
|
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
|
||||||
|
modified = strings.ReplaceAll(modified, part, prefixed)
|
||||||
|
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
|
||||||
|
func isOperatorOrKeywordGorm(s string) bool {
|
||||||
|
s = strings.ToUpper(strings.TrimSpace(s))
|
||||||
|
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
|
||||||
|
for _, kw := range keywords {
|
||||||
|
if s == kw {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||||
g.db = g.db.Or(query, args...)
|
g.db = g.db.Or(query, args...)
|
||||||
return g
|
return g
|
||||||
@@ -238,6 +291,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
// Auto-detect relationship type and choose optimal loading strategy
|
||||||
|
// Get the model from GORM's statement if available
|
||||||
|
if g.db.Statement != nil && g.db.Statement.Model != nil {
|
||||||
|
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
|
||||||
|
|
||||||
|
// Log the detected relationship type
|
||||||
|
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||||
|
|
||||||
|
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||||
|
if relType.ShouldUseJoin() {
|
||||||
|
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||||
|
return g.JoinRelation(relation, apply...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||||
|
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||||
|
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use GORM's Preload (separate query strategy)
|
||||||
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
||||||
if len(apply) == 0 {
|
if len(apply) == 0 {
|
||||||
return db
|
return db
|
||||||
@@ -267,6 +341,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
// JoinRelation uses a JOIN instead of a separate preload query
|
||||||
|
// This is more efficient for many-to-one or one-to-one relationships
|
||||||
|
// as it avoids additional round trips to the database
|
||||||
|
|
||||||
|
// GORM's Joins() method forces a JOIN for the preload
|
||||||
|
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
|
||||||
|
|
||||||
|
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
|
||||||
|
if len(apply) == 0 {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapper := &GormSelectQuery{
|
||||||
|
db: db,
|
||||||
|
inJoinContext: true, // Mark as JOIN context
|
||||||
|
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||||
|
}
|
||||||
|
current := common.SelectQuery(wrapper)
|
||||||
|
|
||||||
|
for _, fn := range apply {
|
||||||
|
if fn != nil {
|
||||||
|
current = fn(current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalGorm, ok := current.(*GormSelectQuery); ok {
|
||||||
|
return finalGorm.db
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
})
|
||||||
|
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||||
g.db = g.db.Order(order)
|
g.db = g.db.Order(order)
|
||||||
return g
|
return g
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type SelectQuery interface {
|
|||||||
LeftJoin(query string, args ...interface{}) SelectQuery
|
LeftJoin(query string, args ...interface{}) SelectQuery
|
||||||
Preload(relation string, conditions ...interface{}) SelectQuery
|
Preload(relation string, conditions ...interface{}) SelectQuery
|
||||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||||
|
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||||
Order(order string) SelectQuery
|
Order(order string) SelectQuery
|
||||||
Limit(n int) SelectQuery
|
Limit(n int) SelectQuery
|
||||||
Offset(n int) SelectQuery
|
Offset(n int) SelectQuery
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -9,81 +8,40 @@ import (
|
|||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
|
||||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
//
|
||||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
// NOTE: For preload queries, table aliases from the parent query are not valid since
|
||||||
|
// the preload executes as a separate query with its own table alias. This function
|
||||||
|
// now simply validates basic syntax without requiring or adding prefixes.
|
||||||
|
// The actual alias normalization happens in the database adapter layer.
|
||||||
|
//
|
||||||
|
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
|
||||||
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||||
if where == "" {
|
if where == "" {
|
||||||
return where, nil
|
return where, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the relation name is already present in the WHERE clause
|
where = strings.TrimSpace(where)
|
||||||
lowerWhere := strings.ToLower(where)
|
|
||||||
lowerRelation := strings.ToLower(relationName)
|
|
||||||
|
|
||||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
// Just do basic validation - don't require or add prefixes
|
||||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
// The database adapter will handle alias normalization
|
||||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
|
||||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
// Check if the WHERE clause contains any qualified column references
|
||||||
// Relation prefix is already present
|
// If it does, log a debug message but don't fail - let the adapter handle it
|
||||||
|
if strings.Contains(where, ".") {
|
||||||
|
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
|
||||||
|
"Note: In preload context, table aliases from parent query are not available. "+
|
||||||
|
"The database adapter will normalize aliases automatically.", relationName, where)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that it's not empty or just whitespace
|
||||||
|
if where == "" {
|
||||||
return where, nil
|
return where, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
// Return the WHERE clause as-is
|
||||||
// we can't safely auto-fix it - require explicit prefix
|
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
|
||||||
if strings.Contains(lowerWhere, " or ") ||
|
return where, nil
|
||||||
strings.Contains(where, "(") ||
|
|
||||||
strings.Contains(where, ")") {
|
|
||||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to add the relation prefix to simple column references
|
|
||||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
|
||||||
// Split by AND to handle multiple conditions (case-insensitive)
|
|
||||||
originalConditions := strings.Split(where, " AND ")
|
|
||||||
|
|
||||||
// If uppercase split didn't work, try lowercase
|
|
||||||
if len(originalConditions) == 1 {
|
|
||||||
originalConditions = strings.Split(where, " and ")
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedConditions := make([]string, 0, len(originalConditions))
|
|
||||||
|
|
||||||
for _, cond := range originalConditions {
|
|
||||||
cond = strings.TrimSpace(cond)
|
|
||||||
if cond == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this condition already has a table prefix (contains a dot)
|
|
||||||
if strings.Contains(cond, ".") {
|
|
||||||
fixedConditions = append(fixedConditions, cond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
|
||||||
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
|
||||||
if IsSQLExpression(lowerCond) {
|
|
||||||
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
|
||||||
fixedConditions = append(fixedConditions, cond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the column name (first identifier before operator)
|
|
||||||
columnName := ExtractColumnName(cond)
|
|
||||||
if columnName == "" {
|
|
||||||
// Can't identify column name, require explicit prefix
|
|
||||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add relation prefix to the column name only
|
|
||||||
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
|
||||||
fixedConditions = append(fixedConditions, fixedCond)
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
|
||||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
|
||||||
return fixedWhere, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||||
@@ -120,17 +78,22 @@ func IsTrivialCondition(cond string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
||||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - where: The WHERE clause string to sanitize
|
// - where: The WHERE clause string to sanitize
|
||||||
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
|
||||||
|
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
|
||||||
// - An empty string if all conditions were trivial or the input was empty
|
// - An empty string if all conditions were trivial or the input was empty
|
||||||
func SanitizeWhereClause(where string, tableName string) string {
|
//
|
||||||
|
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||||
|
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||||
|
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||||
|
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||||
if where == "" {
|
if where == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -146,6 +109,22 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
validColumns = getValidColumnsForTable(tableName)
|
validColumns = getValidColumnsForTable(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||||
|
allowedPrefixes := make(map[string]bool)
|
||||||
|
if tableName != "" {
|
||||||
|
allowedPrefixes[tableName] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add preload relation names as allowed prefixes
|
||||||
|
if len(options) > 0 && options[0] != nil {
|
||||||
|
for pi := range options[0].Preload {
|
||||||
|
if options[0].Preload[pi].Relation != "" {
|
||||||
|
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||||
|
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Split by AND to handle multiple conditions
|
// Split by AND to handle multiple conditions
|
||||||
conditions := splitByAND(where)
|
conditions := splitByAND(where)
|
||||||
|
|
||||||
@@ -166,22 +145,23 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
||||||
// attempt to add it
|
if tableName != "" && hasTablePrefix(condToCheck) {
|
||||||
if tableName != "" && !hasTablePrefix(condToCheck) {
|
// Extract the current prefix and column name
|
||||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||||
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
|
||||||
// Extract the column name and prefix it
|
if currentPrefix != "" && columnName != "" {
|
||||||
columnName := ExtractColumnName(condToCheck)
|
// Check if the prefix is allowed (main table or preload relation)
|
||||||
if columnName != "" {
|
if !allowedPrefixes[currentPrefix] {
|
||||||
// Only prefix if this is a valid column in the model
|
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||||
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
|
||||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||||
// Replace in the original condition (without stripped parens)
|
// Replace the incorrect prefix with the correct main table name
|
||||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
oldRef := currentPrefix + "." + columnName
|
||||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
newRef := tableName + "." + columnName
|
||||||
|
cond = strings.Replace(cond, oldRef, newRef, 1)
|
||||||
|
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -330,6 +310,53 @@ func getValidColumnsForTable(tableName string) map[string]bool {
|
|||||||
return columnMap
|
return columnMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
|
||||||
|
// For example: "users.status = 'active'" returns ("users", "status")
|
||||||
|
// Returns empty strings if no table prefix is found
|
||||||
|
func extractTableAndColumn(cond string) (table string, column string) {
|
||||||
|
// Common SQL operators to find the column reference
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||||
|
|
||||||
|
var columnRef string
|
||||||
|
|
||||||
|
// Find the column reference (left side of the operator)
|
||||||
|
for _, op := range operators {
|
||||||
|
if idx := strings.Index(cond, op); idx > 0 {
|
||||||
|
columnRef = strings.TrimSpace(cond[:idx])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no operator found, the whole condition might be the column reference
|
||||||
|
if columnRef == "" {
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnRef = parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if columnRef == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any quotes
|
||||||
|
columnRef = strings.Trim(columnRef, "`\"'")
|
||||||
|
|
||||||
|
// Check if it contains a dot (qualified reference)
|
||||||
|
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||||
|
table = columnRef[:dotIdx]
|
||||||
|
column = columnRef[dotIdx+1:]
|
||||||
|
|
||||||
|
// Remove quotes from table and column if present
|
||||||
|
table = strings.Trim(table, "`\"'")
|
||||||
|
column = strings.Trim(column, "`\"'")
|
||||||
|
|
||||||
|
return table, column
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
// isValidColumn checks if a column name exists in the valid columns map
|
// isValidColumn checks if a column name exists in the valid columns map
|
||||||
// Handles case-insensitive comparison
|
// Handles case-insensitive comparison
|
||||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||||
|
|||||||
@@ -32,29 +32,41 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid condition with parentheses",
|
name: "valid condition with parentheses - no prefix added",
|
||||||
where: "(status = 'active')",
|
where: "(status = 'active')",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "users.status = 'active'",
|
expected: "status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed trivial and valid conditions",
|
name: "mixed trivial and valid conditions - no prefix added",
|
||||||
where: "true AND status = 'active' AND 1=1",
|
where: "true AND status = 'active' AND 1=1",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "users.status = 'active'",
|
expected: "status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "condition already with table prefix",
|
name: "condition with correct table prefix - unchanged",
|
||||||
where: "users.status = 'active'",
|
where: "users.status = 'active'",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "users.status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple valid conditions",
|
name: "condition with incorrect table prefix - fixed",
|
||||||
where: "status = 'active' AND age > 18",
|
where: "wrong_table.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple conditions with incorrect prefix - fixed",
|
||||||
|
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "users.status = 'active' AND users.age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid conditions without prefix - no prefix added",
|
||||||
|
where: "status = 'active' AND age > 18",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "status = 'active' AND age > 18",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no table name provided",
|
name: "no table name provided",
|
||||||
where: "status = 'active'",
|
where: "status = 'active'",
|
||||||
@@ -67,6 +79,12 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "mixed correct and incorrect prefixes",
|
||||||
|
where: "users.status = 'active' AND wrong_table.age > 18",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -159,6 +177,158 @@ func TestIsTrivialCondition(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractTableAndColumn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedTable string
|
||||||
|
expectedCol string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "qualified column with equals",
|
||||||
|
input: "users.status = 'active'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qualified column with greater than",
|
||||||
|
input: "users.age > 18",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "age",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qualified column with LIKE",
|
||||||
|
input: "users.name LIKE '%john%'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qualified column with IN",
|
||||||
|
input: "users.status IN ('active', 'pending')",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unqualified column",
|
||||||
|
input: "status = 'active'",
|
||||||
|
expectedTable: "",
|
||||||
|
expectedCol: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qualified with backticks",
|
||||||
|
input: "`users`.`status` = 'active'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "schema.table.column reference",
|
||||||
|
input: "public.users.status = 'active'",
|
||||||
|
expectedTable: "public.users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expectedTable: "",
|
||||||
|
expectedCol: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table, col := extractTableAndColumn(tt.input)
|
||||||
|
if table != tt.expectedTable || col != tt.expectedCol {
|
||||||
|
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
|
||||||
|
tt.input, table, col, tt.expectedTable, tt.expectedCol)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
options *RequestOptions
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "preload relation prefix is preserved",
|
||||||
|
where: "Department.name = 'Engineering'",
|
||||||
|
tableName: "users",
|
||||||
|
options: &RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{Relation: "Department"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "Department.name = 'Engineering'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple preload relations - all preserved",
|
||||||
|
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
options: &RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{Relation: "Department"},
|
||||||
|
{Relation: "Manager"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mix of main table and preload relation",
|
||||||
|
where: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||||
|
tableName: "users",
|
||||||
|
options: &RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{Relation: "Department"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incorrect prefix fixed when not a preload relation",
|
||||||
|
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
|
||||||
|
tableName: "users",
|
||||||
|
options: &RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{Relation: "Department"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no options provided - works as before",
|
||||||
|
where: "wrong_table.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
options: nil,
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty preload list - works as before",
|
||||||
|
where: "wrong_table.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
options: &RequestOptions{Preload: []PreloadOption{}},
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var result string
|
||||||
|
if tt.options != nil {
|
||||||
|
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
|
||||||
|
} else {
|
||||||
|
result = SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Test model for model-aware sanitization tests
|
// Test model for model-aware sanitization tests
|
||||||
type MasterTask struct {
|
type MasterTask struct {
|
||||||
ID int `bun:"id,pk"`
|
ID int `bun:"id,pk"`
|
||||||
@@ -182,34 +352,52 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
|||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid column gets prefixed",
|
name: "valid column without prefix - no prefix added",
|
||||||
where: "status = 'active'",
|
where: "status = 'active'",
|
||||||
tableName: "mastertask",
|
tableName: "mastertask",
|
||||||
|
expected: "status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid columns without prefix - no prefix added",
|
||||||
|
where: "status = 'active' AND user_id = 123",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "status = 'active' AND user_id = 123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incorrect table prefix on valid column - fixed",
|
||||||
|
where: "wrong_table.status = 'active'",
|
||||||
|
tableName: "mastertask",
|
||||||
expected: "mastertask.status = 'active'",
|
expected: "mastertask.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple valid columns get prefixed",
|
name: "incorrect prefix on invalid column - not fixed",
|
||||||
where: "status = 'active' AND user_id = 123",
|
where: "wrong_table.invalid_column = 'value'",
|
||||||
tableName: "mastertask",
|
tableName: "mastertask",
|
||||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
expected: "wrong_table.invalid_column = 'value'",
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid column does not get prefixed",
|
|
||||||
where: "invalid_column = 'value'",
|
|
||||||
tableName: "mastertask",
|
|
||||||
expected: "invalid_column = 'value'",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mix of valid and trivial conditions",
|
name: "mix of valid and trivial conditions",
|
||||||
where: "true AND status = 'active' AND 1=1",
|
where: "true AND status = 'active' AND 1=1",
|
||||||
tableName: "mastertask",
|
tableName: "mastertask",
|
||||||
|
expected: "status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parentheses with valid column - no prefix added",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct prefix - unchanged",
|
||||||
|
where: "mastertask.status = 'active'",
|
||||||
|
tableName: "mastertask",
|
||||||
expected: "mastertask.status = 'active'",
|
expected: "mastertask.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "parentheses with valid column",
|
name: "multiple conditions with mixed prefixes",
|
||||||
where: "(status = 'active')",
|
where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
|
||||||
tableName: "mastertask",
|
tableName: "mastertask",
|
||||||
expected: "mastertask.status = 'active'",
|
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
331
pkg/reflection/generic_model_test.go
Normal file
331
pkg/reflection/generic_model_test.go
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
package reflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models for GetModelColumnDetail
|
||||||
|
type TestModelForColumnDetail struct {
|
||||||
|
ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"`
|
||||||
|
Name string `gorm:"column:name;type:varchar(255);not null" json:"name"`
|
||||||
|
Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"`
|
||||||
|
Description string `gorm:"column:description;type:text;null" json:"description"`
|
||||||
|
ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedBase struct {
|
||||||
|
ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"`
|
||||||
|
CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelWithEmbeddedForDetail struct {
|
||||||
|
EmbeddedBase
|
||||||
|
Title string `gorm:"column:title;type:varchar(100);not null" json:"title"`
|
||||||
|
Content string `gorm:"column:content;type:text" json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model with nil embedded pointer
|
||||||
|
type ModelWithNilEmbedded struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
*EmbeddedBase
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelColumnDetail(t *testing.T) {
|
||||||
|
t.Run("simple struct", func(t *testing.T) {
|
||||||
|
model := TestModelForColumnDetail{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Description: "Test description",
|
||||||
|
ForeignKey: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
if len(details) != 5 {
|
||||||
|
t.Errorf("Expected 5 fields, got %d", len(details))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check ID field
|
||||||
|
found := false
|
||||||
|
for _, detail := range details {
|
||||||
|
if detail.Name == "ID" {
|
||||||
|
found = true
|
||||||
|
if detail.SQLName != "rid_test" {
|
||||||
|
t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName)
|
||||||
|
}
|
||||||
|
// Note: primaryKey (without underscore) is not detected as primary_key
|
||||||
|
// The function looks for "identity" or "primary_key" (with underscore)
|
||||||
|
if detail.SQLDataType != "bigserial" {
|
||||||
|
t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType)
|
||||||
|
}
|
||||||
|
if detail.Nullable {
|
||||||
|
t.Errorf("Expected Nullable false, got true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("ID field not found in details")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("struct with embedded fields", func(t *testing.T) {
|
||||||
|
model := ModelWithEmbeddedForDetail{
|
||||||
|
EmbeddedBase: EmbeddedBase{
|
||||||
|
ID: 1,
|
||||||
|
CreatedAt: "2024-01-01",
|
||||||
|
},
|
||||||
|
Title: "Test Title",
|
||||||
|
Content: "Test Content",
|
||||||
|
}
|
||||||
|
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
// Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main
|
||||||
|
if len(details) != 4 {
|
||||||
|
t.Errorf("Expected 4 fields, got %d", len(details))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that embedded field is included
|
||||||
|
foundID := false
|
||||||
|
foundCreatedAt := false
|
||||||
|
for _, detail := range details {
|
||||||
|
if detail.Name == "ID" {
|
||||||
|
foundID = true
|
||||||
|
if detail.SQLKey != "primary_key" {
|
||||||
|
t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if detail.Name == "CreatedAt" {
|
||||||
|
foundCreatedAt = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundID {
|
||||||
|
t.Errorf("Embedded ID field not found")
|
||||||
|
}
|
||||||
|
if !foundCreatedAt {
|
||||||
|
t.Errorf("Embedded CreatedAt field not found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil embedded pointer is skipped", func(t *testing.T) {
|
||||||
|
model := ModelWithNilEmbedded{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
EmbeddedBase: nil, // nil embedded pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
// Should have 2 fields: ID and Name (embedded is nil, so skipped)
|
||||||
|
if len(details) != 2 {
|
||||||
|
t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer to struct", func(t *testing.T) {
|
||||||
|
model := &TestModelForColumnDetail{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
if len(details) != 5 {
|
||||||
|
t.Errorf("Expected 5 fields, got %d", len(details))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid value", func(t *testing.T) {
|
||||||
|
var invalid reflect.Value
|
||||||
|
details := GetModelColumnDetail(invalid)
|
||||||
|
|
||||||
|
if len(details) != 0 {
|
||||||
|
t.Errorf("Expected 0 fields for invalid value, got %d", len(details))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-struct type", func(t *testing.T) {
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(123))
|
||||||
|
|
||||||
|
if len(details) != 0 {
|
||||||
|
t.Errorf("Expected 0 fields for non-struct, got %d", len(details))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nullable and not null detection", func(t *testing.T) {
|
||||||
|
model := TestModelForColumnDetail{}
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
for _, detail := range details {
|
||||||
|
switch detail.Name {
|
||||||
|
case "ID":
|
||||||
|
if detail.Nullable {
|
||||||
|
t.Errorf("ID should not be nullable (has 'not null')")
|
||||||
|
}
|
||||||
|
case "Name":
|
||||||
|
if detail.Nullable {
|
||||||
|
t.Errorf("Name should not be nullable (has 'not null')")
|
||||||
|
}
|
||||||
|
case "Email":
|
||||||
|
if !detail.Nullable {
|
||||||
|
t.Errorf("Email should be nullable (has 'nullable')")
|
||||||
|
}
|
||||||
|
case "Description":
|
||||||
|
if !detail.Nullable {
|
||||||
|
t.Errorf("Description should be nullable (has 'null')")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unique and uniqueindex detection", func(t *testing.T) {
|
||||||
|
type UniqueTestModel struct {
|
||||||
|
ID int `gorm:"column:id;primary_key"`
|
||||||
|
Username string `gorm:"column:username;unique"`
|
||||||
|
Email string `gorm:"column:email;uniqueindex"`
|
||||||
|
}
|
||||||
|
|
||||||
|
model := UniqueTestModel{}
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
for _, detail := range details {
|
||||||
|
switch detail.Name {
|
||||||
|
case "ID":
|
||||||
|
if detail.SQLKey != "primary_key" {
|
||||||
|
t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey)
|
||||||
|
}
|
||||||
|
case "Username":
|
||||||
|
if detail.SQLKey != "unique" {
|
||||||
|
t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey)
|
||||||
|
}
|
||||||
|
case "Email":
|
||||||
|
// The function checks for "unique" first, so uniqueindex is also detected as "unique"
|
||||||
|
// This is expected behavior based on the code logic
|
||||||
|
if detail.SQLKey != "unique" {
|
||||||
|
t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("foreign key detection", func(t *testing.T) {
|
||||||
|
// Note: The foreignkey extraction in generic_model.go has a bug where
|
||||||
|
// it requires ik > 0, so foreignkey at the start won't extract the value
|
||||||
|
type FKTestModel struct {
|
||||||
|
ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
model := FKTestModel{}
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
if len(details) == 0 {
|
||||||
|
t.Fatal("Expected at least 1 field")
|
||||||
|
}
|
||||||
|
|
||||||
|
detail := details[0]
|
||||||
|
if detail.SQLKey != "foreign_key" {
|
||||||
|
t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey)
|
||||||
|
}
|
||||||
|
// Due to the bug in the code (requires ik > 0), the SQLName will be extracted
|
||||||
|
// when foreignkey is not at the beginning of the string
|
||||||
|
if detail.SQLName != "rid_parent" {
|
||||||
|
t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFnFindKeyVal(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
src string
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "find column",
|
||||||
|
src: "column:user_id;primaryKey;type:bigint",
|
||||||
|
key: "column:",
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "find type",
|
||||||
|
src: "column:name;type:varchar(255);not null",
|
||||||
|
key: "type:",
|
||||||
|
expected: "varchar(255)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "key not found",
|
||||||
|
src: "primaryKey;autoIncrement",
|
||||||
|
key: "column:",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "key at end without semicolon",
|
||||||
|
src: "primaryKey;column:id",
|
||||||
|
key: "column:",
|
||||||
|
expected: "id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive search",
|
||||||
|
src: "Column:user_id;primaryKey",
|
||||||
|
key: "column:",
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty src",
|
||||||
|
src: "",
|
||||||
|
key: "column:",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple occurrences (returns first)",
|
||||||
|
src: "column:first;column:second",
|
||||||
|
key: "column:",
|
||||||
|
expected: "first",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := fnFindKeyVal(tt.src, tt.key)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelColumnDetail_FieldValue(t *testing.T) {
|
||||||
|
model := TestModelForColumnDetail{
|
||||||
|
ID: 123,
|
||||||
|
Name: "TestName",
|
||||||
|
Email: "test@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||||
|
|
||||||
|
for _, detail := range details {
|
||||||
|
if !detail.FieldValue.IsValid() {
|
||||||
|
t.Errorf("Field %s has invalid FieldValue", detail.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that FieldValue matches the actual value
|
||||||
|
switch detail.Name {
|
||||||
|
case "ID":
|
||||||
|
if detail.FieldValue.Int() != 123 {
|
||||||
|
t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int())
|
||||||
|
}
|
||||||
|
case "Name":
|
||||||
|
if detail.FieldValue.String() != "TestName" {
|
||||||
|
t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String())
|
||||||
|
}
|
||||||
|
case "Email":
|
||||||
|
if detail.FieldValue.String() != "test@example.com" {
|
||||||
|
t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -750,6 +750,118 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
|
|||||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RelationType represents the type of database relationship
|
||||||
|
type RelationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RelationHasMany RelationType = "has-many" // 1:N - use separate query
|
||||||
|
RelationBelongsTo RelationType = "belongs-to" // N:1 - use JOIN
|
||||||
|
RelationHasOne RelationType = "has-one" // 1:1 - use JOIN
|
||||||
|
RelationManyToMany RelationType = "many-to-many" // M:N - use separate query
|
||||||
|
RelationUnknown RelationType = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShouldUseJoin returns true if the relation type should use a JOIN instead of separate query
|
||||||
|
func (rt RelationType) ShouldUseJoin() bool {
|
||||||
|
return rt == RelationBelongsTo || rt == RelationHasOne
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationType inspects the model's struct tags to determine the relationship type
|
||||||
|
// It checks both Bun and GORM tags to identify the relationship cardinality
|
||||||
|
func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||||
|
if model == nil || fieldName == "" {
|
||||||
|
return RelationUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return RelationUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return RelationUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Check if field name matches (case-insensitive)
|
||||||
|
if !strings.EqualFold(field.Name, fieldName) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Bun tags first
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && strings.Contains(bunTag, "rel:") {
|
||||||
|
// Parse bun relation tag: rel:has-many, rel:belongs-to, rel:has-one, rel:many-to-many
|
||||||
|
parts := strings.Split(bunTag, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "rel:") {
|
||||||
|
relType := strings.TrimPrefix(part, "rel:")
|
||||||
|
switch relType {
|
||||||
|
case "has-many":
|
||||||
|
return RelationHasMany
|
||||||
|
case "belongs-to":
|
||||||
|
return RelationBelongsTo
|
||||||
|
case "has-one":
|
||||||
|
return RelationHasOne
|
||||||
|
case "many-to-many", "m2m":
|
||||||
|
return RelationManyToMany
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check GORM tags
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if gormTag != "" {
|
||||||
|
// GORM uses different patterns:
|
||||||
|
// - foreignKey: usually indicates belongs-to or has-one
|
||||||
|
// - many2many: indicates many-to-many
|
||||||
|
// - Field type (slice vs pointer) helps determine cardinality
|
||||||
|
|
||||||
|
if strings.Contains(gormTag, "many2many:") {
|
||||||
|
return RelationManyToMany
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check field type for cardinality hints
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice {
|
||||||
|
// Slice indicates has-many or many-to-many
|
||||||
|
return RelationHasMany
|
||||||
|
}
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
// Pointer to single struct usually indicates belongs-to or has-one
|
||||||
|
// Check if it has foreignKey (belongs-to) or references (has-one)
|
||||||
|
if strings.Contains(gormTag, "foreignKey:") {
|
||||||
|
return RelationBelongsTo
|
||||||
|
}
|
||||||
|
return RelationHasOne
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to field type inference
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice {
|
||||||
|
// Slice of structs → has-many
|
||||||
|
return RelationHasMany
|
||||||
|
}
|
||||||
|
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
|
||||||
|
// Single struct → belongs-to (default assumption for safety)
|
||||||
|
// Using belongs-to as default ensures we use JOIN, which is safer
|
||||||
|
return RelationBelongsTo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return RelationUnknown
|
||||||
|
}
|
||||||
|
|
||||||
// GetRelationModel gets the model type for a relation field
|
// GetRelationModel gets the model type for a relation field
|
||||||
// It searches for the field by name in the following order (case-insensitive):
|
// It searches for the field by name in the following order (case-insensitive):
|
||||||
// 1. Actual field name
|
// 1. Actual field name
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -316,7 +316,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply cursor filter to query
|
// Apply cursor filter to query
|
||||||
if cursorFilter != "" {
|
if cursorFilter != "" {
|
||||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||||
if sanitizedCursor != "" {
|
if sanitizedCursor != "" {
|
||||||
query = query.Where(sanitizedCursor)
|
query = query.Where(sanitizedCursor)
|
||||||
}
|
}
|
||||||
@@ -1351,7 +1351,9 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||||
|
preloadOpts := &common.RequestOptions{Preload: preloads}
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||||
if len(sanitizedWhere) > 0 {
|
if len(sanitizedWhere) > 0 {
|
||||||
sq = sq.Where(sanitizedWhere)
|
sq = sq.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -450,7 +450,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply the preload with recursive support
|
// Apply the preload with recursive support
|
||||||
query = h.applyPreloadWithRecursion(query, preload, model, 0)
|
query = h.applyPreloadWithRecursion(query, preload, options.Preload, model, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply DISTINCT if requested
|
// Apply DISTINCT if requested
|
||||||
@@ -480,8 +480,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedWhere != "" {
|
if sanitizedWhere != "" {
|
||||||
query = query.Where(sanitizedWhere)
|
query = query.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -490,8 +490,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (OR condition)
|
// Apply custom SQL WHERE clause (OR condition)
|
||||||
if options.CustomSQLOr != "" {
|
if options.CustomSQLOr != "" {
|
||||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedOr != "" {
|
if sanitizedOr != "" {
|
||||||
query = query.WhereOr(sanitizedOr)
|
query = query.WhereOr(sanitizedOr)
|
||||||
}
|
}
|
||||||
@@ -625,7 +625,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply cursor filter to query
|
// Apply cursor filter to query
|
||||||
if cursorFilter != "" {
|
if cursorFilter != "" {
|
||||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedCursor != "" {
|
if sanitizedCursor != "" {
|
||||||
query = query.Where(sanitizedCursor)
|
query = query.Where(sanitizedCursor)
|
||||||
}
|
}
|
||||||
@@ -703,7 +703,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, allPreloads []common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||||
// Log relationship keys if they're specified (from XFiles)
|
// Log relationship keys if they're specified (from XFiles)
|
||||||
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
||||||
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
||||||
@@ -799,7 +799,9 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
|
|
||||||
// Apply WHERE clause
|
// Apply WHERE clause
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||||
|
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||||
if len(sanitizedWhere) > 0 {
|
if len(sanitizedWhere) > 0 {
|
||||||
sq = sq.Where(sanitizedWhere)
|
sq = sq.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -832,7 +834,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
||||||
|
|
||||||
// Recursively apply preload until we reach depth 5
|
// Recursively apply preload until we reach depth 5
|
||||||
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1)
|
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
return query
|
return query
|
||||||
|
|||||||
Reference in New Issue
Block a user