mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 07:44:25 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db2b7e878e | ||
|
|
9572bfc7b8 | ||
|
|
f0962ea1ec | ||
|
|
8fcb065b42 | ||
|
|
dc3b621380 | ||
|
|
a4dd2a7086 | ||
|
|
3ec2e5f15a | ||
|
|
c52afe2825 | ||
|
|
76e98d02c3 |
218
pkg/common/adapters/database/RELATION_LOADING.md
Normal file
218
pkg/common/adapters/database/RELATION_LOADING.md
Normal file
@@ -0,0 +1,218 @@
|
||||
# Automatic Relation Loading Strategies
|
||||
|
||||
## Overview
|
||||
|
||||
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
|
||||
|
||||
Simply use `PreloadRelation()` and the system automatically:
|
||||
- Detects relationship type from Bun/GORM tags
|
||||
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
|
||||
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
|
||||
|
||||
## How It Works
|
||||
|
||||
```go
|
||||
// Just write this - the system handles the rest!
|
||||
db.NewSelect().
|
||||
Model(&links).
|
||||
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
|
||||
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
|
||||
Scan(ctx, &links)
|
||||
```
|
||||
|
||||
### Detection Logic
|
||||
|
||||
The system inspects your model's struct tags:
|
||||
|
||||
**Bun models:**
|
||||
```go
|
||||
type Link struct {
|
||||
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
|
||||
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
|
||||
}
|
||||
```
|
||||
|
||||
**GORM models:**
|
||||
```go
|
||||
type Link struct {
|
||||
ProviderID int
|
||||
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
|
||||
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
|
||||
}
|
||||
```
|
||||
|
||||
**Type inference (fallback):**
|
||||
- `[]Type` (slice) → has-many → Separate query
|
||||
- `*Type` (pointer) → belongs-to → JOIN
|
||||
- `Type` (struct) → belongs-to → JOIN
|
||||
|
||||
### What Gets Logged
|
||||
|
||||
Enable debug logging to see strategy selection:
|
||||
|
||||
```go
|
||||
bunAdapter.EnableQueryDebug()
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```
|
||||
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
|
||||
INFO: Using JOIN strategy for belongs-to relation 'Provider'
|
||||
DEBUG: PreloadRelation 'Links' detected as: has-many
|
||||
DEBUG: Using separate query for has-many relation 'Links'
|
||||
```
|
||||
|
||||
## Relationship Types
|
||||
|
||||
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|
||||
|---------|--------------|------------|----------|-----|
|
||||
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
|
||||
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
|
||||
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
|
||||
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
|
||||
|
||||
## Manual Override
|
||||
|
||||
If you need to force a specific strategy, use `JoinRelation()`:
|
||||
|
||||
```go
|
||||
// Force JOIN even for has-many (not recommended)
|
||||
db.NewSelect().
|
||||
Model(&providers).
|
||||
JoinRelation("Links"). // Explicitly use JOIN
|
||||
Scan(ctx, &providers)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Automatic Strategy Selection (Recommended)
|
||||
|
||||
```go
|
||||
// Example 1: Loading parent provider for each link
|
||||
// System detects belongs-to → uses JOIN automatically
|
||||
db.NewSelect().
|
||||
Model(&links).
|
||||
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("active = ?", true)
|
||||
}).
|
||||
Scan(ctx, &links)
|
||||
|
||||
// Generated SQL: Single query with JOIN
|
||||
// SELECT links.*, providers.*
|
||||
// FROM links
|
||||
// LEFT JOIN providers ON links.provider_id = providers.id
|
||||
// WHERE providers.active = true
|
||||
|
||||
// Example 2: Loading child links for each provider
|
||||
// System detects has-many → uses separate query automatically
|
||||
db.NewSelect().
|
||||
Model(&providers).
|
||||
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("active = ?", true)
|
||||
}).
|
||||
Scan(ctx, &providers)
|
||||
|
||||
// Generated SQL: Two queries
|
||||
// Query 1: SELECT * FROM providers
|
||||
// Query 2: SELECT * FROM links
|
||||
// WHERE provider_id IN (1, 2, 3, ...)
|
||||
// AND active = true
|
||||
```
|
||||
|
||||
### Mixed Relationships
|
||||
|
||||
```go
|
||||
type Order struct {
|
||||
ID int
|
||||
CustomerID int
|
||||
Customer *Customer `bun:"rel:belongs-to"` // JOIN
|
||||
Items []Item `bun:"rel:has-many"` // Separate
|
||||
Invoice *Invoice `bun:"rel:has-one"` // JOIN
|
||||
}
|
||||
|
||||
// All three handled optimally!
|
||||
db.NewSelect().
|
||||
Model(&orders).
|
||||
PreloadRelation("Customer"). // → JOIN (many-to-one)
|
||||
PreloadRelation("Items"). // → Separate (one-to-many)
|
||||
PreloadRelation("Invoice"). // → JOIN (one-to-one)
|
||||
Scan(ctx, &orders)
|
||||
```
|
||||
|
||||
## Performance Benefits
|
||||
|
||||
### Before (Manual Strategy Selection)
|
||||
|
||||
```go
|
||||
// You had to remember which to use:
|
||||
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
|
||||
.PreloadRelation("Links") // Which is more efficient here?
|
||||
```
|
||||
|
||||
### After (Automatic Selection)
|
||||
|
||||
```go
|
||||
// Just use PreloadRelation everywhere:
|
||||
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
|
||||
.PreloadRelation("Links") // ✓ System uses separate query automatically
|
||||
```
|
||||
|
||||
## Migration Guide
|
||||
|
||||
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
|
||||
|
||||
```go
|
||||
// Before: Always used separate query
|
||||
.PreloadRelation("Provider") // Inefficient: extra round trip
|
||||
|
||||
// After: Automatic optimization
|
||||
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Supported Bun Tags
|
||||
- `rel:has-many` → Separate query
|
||||
- `rel:belongs-to` → JOIN
|
||||
- `rel:has-one` → JOIN
|
||||
- `rel:many-to-many` or `rel:m2m` → Separate query
|
||||
|
||||
### Supported GORM Patterns
|
||||
- `many2many:` tag → Separate query
|
||||
- `foreignKey:` tag → JOIN (belongs-to)
|
||||
- `[]Type` slice without many2many → Separate query (has-many)
|
||||
- `*Type` pointer with foreignKey → JOIN (belongs-to)
|
||||
- `*Type` pointer without foreignKey → JOIN (has-one)
|
||||
|
||||
### Fallback Behavior
|
||||
- `[]Type` (slice) → Separate query (safe default for collections)
|
||||
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
|
||||
- Unknown → Separate query (safest default)
|
||||
|
||||
## Debugging
|
||||
|
||||
To see strategy selection in action:
|
||||
|
||||
```go
|
||||
// Enable debug logging
|
||||
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
|
||||
|
||||
// Run your query
|
||||
db.NewSelect().
|
||||
Model(&records).
|
||||
PreloadRelation("RelationName").
|
||||
Scan(ctx, &records)
|
||||
|
||||
// Check logs for:
|
||||
// - "PreloadRelation 'X' detected as: belongs-to"
|
||||
// - "Using JOIN strategy for belongs-to relation 'X'"
|
||||
// - Actual SQL queries executed
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use PreloadRelation() for everything** - Let the system optimize
|
||||
2. **Define proper relationship tags** - Ensures correct detection
|
||||
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
|
||||
4. **Enable debug logging during development** - Verify optimal strategies are chosen
|
||||
5. **Trust the system** - It's designed to choose correctly based on relationship type
|
||||
67
pkg/common/adapters/database/alias_test.go
Normal file
67
pkg/common/adapters/database/alias_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeTableAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedAlias string
|
||||
tableName string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "strips incorrect alias from simple condition",
|
||||
query: "APIL.rid_hub = 2576",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = 2576",
|
||||
},
|
||||
{
|
||||
name: "keeps correct alias",
|
||||
query: "apiproviderlink.rid_hub = 2576",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "apiproviderlink.rid_hub = 2576",
|
||||
},
|
||||
{
|
||||
name: "strips incorrect alias with multiple conditions",
|
||||
query: "APIL.rid_hub = ? AND APIL.active = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ? AND active = ?",
|
||||
},
|
||||
{
|
||||
name: "handles mixed correct and incorrect aliases",
|
||||
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ? AND apiproviderlink.active = ?",
|
||||
},
|
||||
{
|
||||
name: "handles parentheses",
|
||||
query: "(APIL.rid_hub = ?)",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "(rid_hub = ?)",
|
||||
},
|
||||
{
|
||||
name: "no alias in query",
|
||||
query: "rid_hub = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ?",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
|
||||
if got != tt.want {
|
||||
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
@@ -15,6 +16,24 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// QueryDebugHook is a Bun query hook that logs all SQL queries including preloads
|
||||
type QueryDebugHook struct{}
|
||||
|
||||
func (h *QueryDebugHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {
|
||||
query := event.Query
|
||||
duration := time.Since(event.StartTime)
|
||||
|
||||
if event.Err != nil {
|
||||
logger.Error("SQL Query Failed [%s]: %s. Error: %v", duration, query, event.Err)
|
||||
} else {
|
||||
logger.Debug("SQL Query Success [%s]: %s", duration, query)
|
||||
}
|
||||
}
|
||||
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
@@ -26,6 +45,20 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
return &BunAdapter{db: db}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (b *BunAdapter) EnableQueryDebug() {
|
||||
b.db.AddQueryHook(&QueryDebugHook{})
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
// DisableQueryDebug removes all query hooks
|
||||
func (b *BunAdapter) DisableQueryDebug() {
|
||||
// Create a new DB without hooks
|
||||
// Note: Bun doesn't have a RemoveQueryHook, so we'd need to track hooks manually
|
||||
logger.Info("To disable query debug, recreate the BunAdapter without adding the hook")
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
@@ -107,6 +140,8 @@ type BunSelectQuery struct {
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
}
|
||||
|
||||
// deferredPreload represents a preload that will be executed as a separate query
|
||||
@@ -156,10 +191,106 @@ func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.Se
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
// If we're in a JOIN context, add table prefix to unqualified columns
|
||||
if b.inJoinContext && b.joinTableAlias != "" {
|
||||
query = addTablePrefix(query, b.joinTableAlias)
|
||||
} else if b.tableAlias != "" && b.tableName != "" {
|
||||
// If we have a table alias defined, check if the query references a different alias
|
||||
// This can happen in preloads where the user expects a certain alias but Bun generates another
|
||||
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
|
||||
}
|
||||
b.query = b.query.Where(query, args...)
|
||||
return b
|
||||
}
|
||||
|
||||
// addTablePrefix adds a table prefix to unqualified column references
|
||||
// This is used in JOIN contexts where conditions must reference the joined table
|
||||
func addTablePrefix(query, tableAlias string) string {
|
||||
if tableAlias == "" || query == "" {
|
||||
return query
|
||||
}
|
||||
|
||||
// Split on spaces and parentheses to find column references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||
})
|
||||
|
||||
modified := query
|
||||
for _, part := range parts {
|
||||
// Check if this looks like an unqualified column reference
|
||||
// (no dot, and likely a column name before an operator)
|
||||
if !strings.Contains(part, ".") {
|
||||
// Extract potential column name (before = or other operators)
|
||||
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
|
||||
if strings.Contains(part, op) {
|
||||
colName := strings.Split(part, op)[0]
|
||||
colName = strings.TrimSpace(colName)
|
||||
if colName != "" && !isOperatorOrKeyword(colName) {
|
||||
// Add table prefix
|
||||
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
|
||||
modified = strings.ReplaceAll(modified, part, prefixed)
|
||||
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
|
||||
func isOperatorOrKeyword(s string) bool {
|
||||
s = strings.ToUpper(strings.TrimSpace(s))
|
||||
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
|
||||
for _, kw := range keywords {
|
||||
if s == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// normalizeTableAlias replaces table alias prefixes in SQL conditions
|
||||
// This handles cases where a user references a table alias that doesn't match
|
||||
// what Bun generates (common in preload contexts)
|
||||
func normalizeTableAlias(query, expectedAlias, tableName string) string {
|
||||
// Pattern: <word>.<column> where <word> might be an incorrect alias
|
||||
// We'll look for patterns like "APIL.column" and either:
|
||||
// 1. Remove the alias prefix entirely (safest)
|
||||
// 2. Replace with the expected alias
|
||||
|
||||
// For now, we'll use a simple approach: if the query contains a dot (qualified reference)
|
||||
// and that prefix is not the expected alias or table name, strip it
|
||||
|
||||
// Split on spaces and parentheses to find qualified references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||
})
|
||||
|
||||
modified := query
|
||||
for _, part := range parts {
|
||||
// Check if this looks like a qualified column reference
|
||||
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
|
||||
prefix := part[:dotIndex]
|
||||
column := part[dotIndex+1:]
|
||||
|
||||
// Check if the prefix matches our expected alias or table name (case-insensitive)
|
||||
if !strings.EqualFold(prefix, expectedAlias) &&
|
||||
!strings.EqualFold(prefix, tableName) &&
|
||||
!strings.EqualFold(prefix, strings.ToLower(tableName)) {
|
||||
// This is a different alias - remove the prefix
|
||||
logger.Debug("Stripping incorrect alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
|
||||
// Replace the qualified reference with just the column name
|
||||
modified = strings.ReplaceAll(modified, part, column)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.WhereOr(query, args...)
|
||||
return b
|
||||
@@ -288,6 +419,27 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
// }
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Auto-detect relationship type and choose optimal loading strategy
|
||||
// Get the model from the query if available
|
||||
model := b.query.GetModel()
|
||||
if model != nil && model.Value() != nil {
|
||||
relType := reflection.GetRelationType(model.Value(), relation)
|
||||
|
||||
// Log the detected relationship type
|
||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||
|
||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||
if relType.ShouldUseJoin() {
|
||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||
return b.JoinRelation(relation, apply...)
|
||||
}
|
||||
|
||||
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this relation chain would create problematic long aliases
|
||||
relationParts := strings.Split(relation, ".")
|
||||
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||
@@ -350,6 +502,28 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
db: b.db,
|
||||
}
|
||||
|
||||
// Try to extract table name and alias from the preload model
|
||||
if model := sq.GetModel(); model != nil && model.Value() != nil {
|
||||
modelValue := model.Value()
|
||||
|
||||
// Extract table name if model implements TableNameProvider
|
||||
if provider, ok := modelValue.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
|
||||
}
|
||||
|
||||
// Extract table alias if model implements TableAliasProvider
|
||||
if provider, ok := modelValue.(common.TableAliasProvider); ok {
|
||||
wrapper.tableAlias = provider.TableAlias()
|
||||
// Apply the alias to the Bun query so conditions can reference it
|
||||
if wrapper.tableAlias != "" {
|
||||
// Note: Bun's Relation() already sets up the table, but we can add
|
||||
// the alias explicitly if needed
|
||||
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start with the interface value (not pointer)
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
@@ -372,6 +546,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// JoinRelation uses a LEFT JOIN instead of a separate query
|
||||
// This is more efficient for many-to-one or one-to-one relationships
|
||||
|
||||
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
|
||||
|
||||
// Wrap the apply functions to automatically add table prefix to WHERE conditions
|
||||
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
|
||||
return func(q common.SelectQuery) common.SelectQuery {
|
||||
// Create a special wrapper that adds prefixes to WHERE conditions
|
||||
if bunQuery, ok := q.(*BunSelectQuery); ok {
|
||||
// Mark this query as being in JOIN context
|
||||
bunQuery.inJoinContext = true
|
||||
bunQuery.joinTableAlias = strings.ToLower(relation)
|
||||
}
|
||||
return originalFn(q)
|
||||
}
|
||||
}(fn)
|
||||
wrappedApply = append(wrappedApply, wrappedFn)
|
||||
}
|
||||
}
|
||||
|
||||
// Use PreloadRelation with the wrapped functions
|
||||
// Bun's Relation() will use JOIN for belongs-to and has-one relations
|
||||
return b.PreloadRelation(relation, wrappedApply...)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
b.query = b.query.Order(order)
|
||||
return b
|
||||
@@ -410,6 +614,9 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx, dest)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -438,6 +645,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -573,15 +783,25 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
count, err := b.query.Count(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||
// This is needed when only Table() is set without a model
|
||||
err = b.db.NewSelect().
|
||||
countQuery := b.db.NewSelect().
|
||||
TableExpr("(?) AS subquery", b.query).
|
||||
ColumnExpr("COUNT(*)").
|
||||
Scan(ctx, &count)
|
||||
ColumnExpr("COUNT(*)")
|
||||
err = countQuery.Scan(ctx, &count)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := countQuery.String()
|
||||
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -592,7 +812,13 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
return b.query.Exists(ctx)
|
||||
exists, err = b.query.Exists(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
@@ -729,6 +955,11 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -759,6 +990,11 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
return &GormAdapter{db: db}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||
g.db = g.db.Debug()
|
||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||
return g
|
||||
}
|
||||
|
||||
// DisableQueryDebug disables query debugging
|
||||
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
// GORM's Debug() creates a new session, so we need to get the base DB
|
||||
// This is a simplified implementation
|
||||
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db}
|
||||
}
|
||||
@@ -88,10 +104,12 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
db *gorm.DB
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
@@ -135,10 +153,61 @@ func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.S
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
// If we're in a JOIN context, add table prefix to unqualified columns
|
||||
if g.inJoinContext && g.joinTableAlias != "" {
|
||||
query = addTablePrefixGorm(query, g.joinTableAlias)
|
||||
}
|
||||
g.db = g.db.Where(query, args...)
|
||||
return g
|
||||
}
|
||||
|
||||
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
|
||||
func addTablePrefixGorm(query, tableAlias string) string {
|
||||
if tableAlias == "" || query == "" {
|
||||
return query
|
||||
}
|
||||
|
||||
// Split on spaces and parentheses to find column references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||
})
|
||||
|
||||
modified := query
|
||||
for _, part := range parts {
|
||||
// Check if this looks like an unqualified column reference
|
||||
if !strings.Contains(part, ".") {
|
||||
// Extract potential column name (before = or other operators)
|
||||
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
|
||||
if strings.Contains(part, op) {
|
||||
colName := strings.Split(part, op)[0]
|
||||
colName = strings.TrimSpace(colName)
|
||||
if colName != "" && !isOperatorOrKeywordGorm(colName) {
|
||||
// Add table prefix
|
||||
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
|
||||
modified = strings.ReplaceAll(modified, part, prefixed)
|
||||
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
|
||||
func isOperatorOrKeywordGorm(s string) bool {
|
||||
s = strings.ToUpper(strings.TrimSpace(s))
|
||||
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
|
||||
for _, kw := range keywords {
|
||||
if s == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
g.db = g.db.Or(query, args...)
|
||||
return g
|
||||
@@ -222,6 +291,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Auto-detect relationship type and choose optimal loading strategy
|
||||
// Get the model from GORM's statement if available
|
||||
if g.db.Statement != nil && g.db.Statement.Model != nil {
|
||||
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
|
||||
|
||||
// Log the detected relationship type
|
||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||
|
||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||
if relType.ShouldUseJoin() {
|
||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||
return g.JoinRelation(relation, apply...)
|
||||
}
|
||||
|
||||
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Use GORM's Preload (separate query strategy)
|
||||
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
||||
if len(apply) == 0 {
|
||||
return db
|
||||
@@ -251,6 +341,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// JoinRelation uses a JOIN instead of a separate preload query
|
||||
// This is more efficient for many-to-one or one-to-one relationships
|
||||
// as it avoids additional round trips to the database
|
||||
|
||||
// GORM's Joins() method forces a JOIN for the preload
|
||||
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
|
||||
|
||||
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
|
||||
if len(apply) == 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
}
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
current = fn(current)
|
||||
}
|
||||
}
|
||||
|
||||
if finalGorm, ok := current.(*GormSelectQuery); ok {
|
||||
return finalGorm.db
|
||||
}
|
||||
|
||||
return db
|
||||
})
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
g.db = g.db.Order(order)
|
||||
return g
|
||||
@@ -282,7 +408,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
err = g.db.WithContext(ctx).Find(dest).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Find(dest)
|
||||
})
|
||||
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
@@ -294,7 +428,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Find(g.db.Statement.Model)
|
||||
})
|
||||
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
@@ -306,6 +448,13 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
}()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Count(&count64)
|
||||
})
|
||||
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return int(count64), err
|
||||
}
|
||||
|
||||
@@ -318,6 +467,13 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
}()
|
||||
var count int64
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Limit(1).Count(&count)
|
||||
})
|
||||
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
@@ -456,6 +612,13 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Updates(g.updates)
|
||||
})
|
||||
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -488,6 +651,13 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Delete(g.model)
|
||||
})
|
||||
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ type SelectQuery interface {
|
||||
LeftJoin(query string, args ...interface{}) SelectQuery
|
||||
Preload(relation string, conditions ...interface{}) SelectQuery
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -9,81 +8,40 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
|
||||
//
|
||||
// NOTE: For preload queries, table aliases from the parent query are not valid since
|
||||
// the preload executes as a separate query with its own table alias. This function
|
||||
// now simply validates basic syntax without requiring or adding prefixes.
|
||||
// The actual alias normalization happens in the database adapter layer.
|
||||
//
|
||||
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
|
||||
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// Check if the relation name is already present in the WHERE clause
|
||||
lowerWhere := strings.ToLower(where)
|
||||
lowerRelation := strings.ToLower(relationName)
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||
// Relation prefix is already present
|
||||
// Just do basic validation - don't require or add prefixes
|
||||
// The database adapter will handle alias normalization
|
||||
|
||||
// Check if the WHERE clause contains any qualified column references
|
||||
// If it does, log a debug message but don't fail - let the adapter handle it
|
||||
if strings.Contains(where, ".") {
|
||||
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
|
||||
"Note: In preload context, table aliases from parent query are not available. "+
|
||||
"The database adapter will normalize aliases automatically.", relationName, where)
|
||||
}
|
||||
|
||||
// Validate that it's not empty or just whitespace
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||
// we can't safely auto-fix it - require explicit prefix
|
||||
if strings.Contains(lowerWhere, " or ") ||
|
||||
strings.Contains(where, "(") ||
|
||||
strings.Contains(where, ")") {
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||
}
|
||||
|
||||
// Try to add the relation prefix to simple column references
|
||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||
// Split by AND to handle multiple conditions (case-insensitive)
|
||||
originalConditions := strings.Split(where, " AND ")
|
||||
|
||||
// If uppercase split didn't work, try lowercase
|
||||
if len(originalConditions) == 1 {
|
||||
originalConditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
fixedConditions := make([]string, 0, len(originalConditions))
|
||||
|
||||
for _, cond := range originalConditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this condition already has a table prefix (contains a dot)
|
||||
if strings.Contains(cond, ".") {
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||
if IsSQLExpression(lowerCond) {
|
||||
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the column name (first identifier before operator)
|
||||
columnName := ExtractColumnName(cond)
|
||||
if columnName == "" {
|
||||
// Can't identify column name, require explicit prefix
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||
}
|
||||
|
||||
// Add relation prefix to the column name only
|
||||
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||
fixedConditions = append(fixedConditions, fixedCond)
|
||||
}
|
||||
|
||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||
return fixedWhere, nil
|
||||
// Return the WHERE clause as-is
|
||||
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,18 +9,18 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestSqlInt16 tests SqlInt16 type
|
||||
func TestSqlInt16(t *testing.T) {
|
||||
// TestNewSqlInt16 tests NewSqlInt16 type
|
||||
func TestNewSqlInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, SqlInt16(42)},
|
||||
{"int32", int32(100), SqlInt16(100)},
|
||||
{"int64", int64(200), SqlInt16(200)},
|
||||
{"string", "123", SqlInt16(123)},
|
||||
{"nil", nil, SqlInt16(0)},
|
||||
{"int", 42, Null(int16(42), true)},
|
||||
{"int32", int32(100), NewSqlInt16(100)},
|
||||
{"int64", int64(200), NewSqlInt16(200)},
|
||||
{"string", "123", NewSqlInt16(123)},
|
||||
{"nil", nil, Null(int16(0), false)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_Value(t *testing.T) {
|
||||
func TestNewSqlInt16_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", SqlInt16(0), nil},
|
||||
{"positive", SqlInt16(42), int64(42)},
|
||||
{"negative", SqlInt16(-10), int64(-10)},
|
||||
{"zero", Null(int16(0), false), nil},
|
||||
{"positive", NewSqlInt16(42), int16(42)},
|
||||
{"negative", NewSqlInt16(-10), int16(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_JSON(t *testing.T) {
|
||||
n := SqlInt16(42)
|
||||
func TestNewSqlInt16_JSON(t *testing.T) {
|
||||
n := NewSqlInt16(42)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(n)
|
||||
@@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if n2 != 123 {
|
||||
t.Errorf("expected 123, got %d", n2)
|
||||
if n2.Int64() != 123 {
|
||||
t.Errorf("expected 123, got %d", n2.Int64())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlInt64 tests SqlInt64 type
|
||||
func TestSqlInt64(t *testing.T) {
|
||||
// TestNewSqlInt64 tests NewSqlInt64 type
|
||||
func TestNewSqlInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt64
|
||||
}{
|
||||
{"int", 42, SqlInt64(42)},
|
||||
{"int32", int32(100), SqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), SqlInt64(100)},
|
||||
{"uint64", uint64(200), SqlInt64(200)},
|
||||
{"nil", nil, SqlInt64(0)},
|
||||
{"int", 42, NewSqlInt64(42)},
|
||||
{"int32", int32(100), NewSqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), NewSqlInt64(100)},
|
||||
{"uint64", uint64(200), NewSqlInt64(200)},
|
||||
{"nil", nil, SqlInt64{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) {
|
||||
if n.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||
}
|
||||
if tt.valid && n.Float64 != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
||||
if tt.valid && n.Float64() != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) {
|
||||
if err := ts.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if ts.GetTime().IsZero() {
|
||||
if ts.Time().IsZero() {
|
||||
t.Error("expected non-zero time")
|
||||
}
|
||||
})
|
||||
@@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) {
|
||||
|
||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
ts := SqlTimeStamp(now)
|
||||
ts := NewSqlTimeStamp(now)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(ts)
|
||||
@@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if ts2.GetTime().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
||||
if ts2.Time().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
|
||||
}
|
||||
|
||||
// Test null
|
||||
@@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlDate_JSON(t *testing.T) {
|
||||
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(date)
|
||||
@@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) {
|
||||
if u.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||
}
|
||||
if tt.valid && u.String != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||
if tt.valid && u.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -480,13 +480,13 @@ func TestSqlUUID_Scan(t *testing.T) {
|
||||
|
||||
func TestSqlUUID_Value(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
u := NewSqlUUID(testUUID)
|
||||
|
||||
val, err := u.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID.String() {
|
||||
if val != testUUID {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
@@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) {
|
||||
|
||||
func TestSqlUUID_JSON(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
u := NewSqlUUID(testUUID)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(u)
|
||||
@@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if u2.String != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||
if u2.String() != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
|
||||
}
|
||||
|
||||
// Test null
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
// MockDatabase implements common.Database interface for testing
|
||||
type MockDatabase struct {
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
||||
}
|
||||
|
||||
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
}{
|
||||
{
|
||||
name: "No variables",
|
||||
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
|
||||
// TestGetIPAddress tests IP address extraction
|
||||
func TestGetIPAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
userCtx := &security.UserContext{
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "456",
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "ABC456",
|
||||
SessionRID: 456,
|
||||
}
|
||||
|
||||
metainfo := map[string]interface{}{
|
||||
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "456")
|
||||
},
|
||||
}, {
|
||||
name: "Replace [id_session]",
|
||||
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "ABC456")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,10 @@ func NewModelRegistry() *DefaultModelRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
func GetDefaultRegistry() *DefaultModelRegistry {
|
||||
return defaultRegistry
|
||||
}
|
||||
|
||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
defer registriesMutex.Unlock()
|
||||
|
||||
@@ -750,6 +750,118 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
|
||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||
}
|
||||
|
||||
// RelationType represents the type of database relationship
|
||||
type RelationType string
|
||||
|
||||
const (
|
||||
RelationHasMany RelationType = "has-many" // 1:N - use separate query
|
||||
RelationBelongsTo RelationType = "belongs-to" // N:1 - use JOIN
|
||||
RelationHasOne RelationType = "has-one" // 1:1 - use JOIN
|
||||
RelationManyToMany RelationType = "many-to-many" // M:N - use separate query
|
||||
RelationUnknown RelationType = "unknown"
|
||||
)
|
||||
|
||||
// ShouldUseJoin returns true if the relation type should use a JOIN instead of separate query
|
||||
func (rt RelationType) ShouldUseJoin() bool {
|
||||
return rt == RelationBelongsTo || rt == RelationHasOne
|
||||
}
|
||||
|
||||
// GetRelationType inspects the model's struct tags to determine the relationship type
|
||||
// It checks both Bun and GORM tags to identify the relationship cardinality
|
||||
func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||
if model == nil || fieldName == "" {
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
// Find the field
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Check if field name matches (case-insensitive)
|
||||
if !strings.EqualFold(field.Name, fieldName) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check Bun tags first
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && strings.Contains(bunTag, "rel:") {
|
||||
// Parse bun relation tag: rel:has-many, rel:belongs-to, rel:has-one, rel:many-to-many
|
||||
parts := strings.Split(bunTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "rel:") {
|
||||
relType := strings.TrimPrefix(part, "rel:")
|
||||
switch relType {
|
||||
case "has-many":
|
||||
return RelationHasMany
|
||||
case "belongs-to":
|
||||
return RelationBelongsTo
|
||||
case "has-one":
|
||||
return RelationHasOne
|
||||
case "many-to-many", "m2m":
|
||||
return RelationManyToMany
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check GORM tags
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
// GORM uses different patterns:
|
||||
// - foreignKey: usually indicates belongs-to or has-one
|
||||
// - many2many: indicates many-to-many
|
||||
// - Field type (slice vs pointer) helps determine cardinality
|
||||
|
||||
if strings.Contains(gormTag, "many2many:") {
|
||||
return RelationManyToMany
|
||||
}
|
||||
|
||||
// Check field type for cardinality hints
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
// Slice indicates has-many or many-to-many
|
||||
return RelationHasMany
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
// Pointer to single struct usually indicates belongs-to or has-one
|
||||
// Check if it has foreignKey (belongs-to) or references (has-one)
|
||||
if strings.Contains(gormTag, "foreignKey:") {
|
||||
return RelationBelongsTo
|
||||
}
|
||||
return RelationHasOne
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to field type inference
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
// Slice of structs → has-many
|
||||
return RelationHasMany
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
|
||||
// Single struct → belongs-to (default assumption for safety)
|
||||
// Using belongs-to as default ensures we use JOIN, which is safer
|
||||
return RelationBelongsTo
|
||||
}
|
||||
}
|
||||
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
// GetRelationModel gets the model type for a relation field
|
||||
// It searches for the field by name in the following order (case-insensitive):
|
||||
// 1. Actual field name
|
||||
|
||||
Reference in New Issue
Block a user