mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 08:14:25 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cd04b7083 | ||
|
|
0d4909054c | ||
|
|
745564f2e7 | ||
|
|
311e50bfdd |
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
@@ -43,12 +44,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
|||||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result, err := b.db.ExecContext(ctx, query, args...)
|
result, err := b.db.ExecContext(ctx, query, args...)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +84,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||||
// Create adapter with transaction
|
// Create adapter with transaction
|
||||||
adapter := &BunTxAdapter{tx: tx}
|
adapter := &BunTxAdapter{tx: tx}
|
||||||
@@ -219,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
|
|
||||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if len(apply) == 0 {
|
if len(apply) == 0 {
|
||||||
return sq
|
return sq
|
||||||
}
|
}
|
||||||
@@ -276,15 +297,38 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if dest == nil {
|
||||||
|
return fmt.Errorf("destination cannot be nil")
|
||||||
|
}
|
||||||
return b.query.Scan(ctx, dest)
|
return b.query.Scan(ctx, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
|
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if b.query.GetModel() == nil {
|
||||||
|
return fmt.Errorf("model is nil")
|
||||||
|
}
|
||||||
|
|
||||||
return b.query.Scan(ctx)
|
return b.query.Scan(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
// If Model() was set, use bun's native Count() which works properly
|
// If Model() was set, use bun's native Count() which works properly
|
||||||
if b.hasModel {
|
if b.hasModel {
|
||||||
count, err := b.query.Count(ctx)
|
count, err := b.query.Count(ctx)
|
||||||
@@ -293,15 +337,20 @@ func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
|||||||
|
|
||||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||||
// This is needed when only Table() is set without a model
|
// This is needed when only Table() is set without a model
|
||||||
var count int
|
err = b.db.NewSelect().
|
||||||
err := b.db.NewSelect().
|
|
||||||
TableExpr("(?) AS subquery", b.query).
|
TableExpr("(?) AS subquery", b.query).
|
||||||
ColumnExpr("COUNT(*)").
|
ColumnExpr("COUNT(*)").
|
||||||
Scan(ctx, &count)
|
Scan(ctx, &count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
|
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
return b.query.Exists(ctx)
|
return b.query.Exists(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,7 +369,6 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
|||||||
|
|
||||||
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||||
if b.hasModel {
|
if b.hasModel {
|
||||||
// If model is set, do not override table name
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
@@ -347,7 +395,12 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if b.values != nil && len(b.values) > 0 {
|
if b.values != nil && len(b.values) > 0 {
|
||||||
if !b.hasModel {
|
if !b.hasModel {
|
||||||
// If no model was set, use the values map as the model
|
// If no model was set, use the values map as the model
|
||||||
@@ -428,7 +481,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
@@ -453,7 +511,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
@@ -38,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
|||||||
return &GormDeleteQuery{db: g.db}
|
return &GormDeleteQuery{db: g.db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
return g.db.WithContext(ctx).Rollback().Error
|
return g.db.WithContext(ctx).Rollback().Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
adapter := &GormAdapter{db: tx}
|
adapter := &GormAdapter{db: tx}
|
||||||
return fn(adapter)
|
return fn(adapter)
|
||||||
@@ -255,26 +271,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return g.db.WithContext(ctx).Find(dest).Error
|
return g.db.WithContext(ctx).Find(dest).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
|
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if g.db.Statement.Model == nil {
|
if g.db.Statement.Model == nil {
|
||||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||||
}
|
}
|
||||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
|
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
var count int64
|
defer func() {
|
||||||
err := g.db.WithContext(ctx).Count(&count).Error
|
if r := recover(); r != nil {
|
||||||
return int(count), err
|
err = logger.HandlePanic("GormSelectQuery.Count", r)
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var count64 int64
|
||||||
|
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||||
|
return int(count64), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
|
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormSelectQuery.Exists", r)
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
var count int64
|
var count int64
|
||||||
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,7 +352,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
var result *gorm.DB
|
var result *gorm.DB
|
||||||
switch {
|
switch {
|
||||||
case g.model != nil:
|
case g.model != nil:
|
||||||
@@ -401,7 +444,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
@@ -428,7 +476,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result := g.db.WithContext(ctx).Delete(g.model)
|
result := g.db.WithContext(ctx).Delete(g.model)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|||||||
136
pkg/common/sql_helpers.go
Normal file
136
pkg/common/sql_helpers.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
|
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||||
|
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||||
|
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||||
|
if where == "" {
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the relation name is already present in the WHERE clause
|
||||||
|
lowerWhere := strings.ToLower(where)
|
||||||
|
lowerRelation := strings.ToLower(relationName)
|
||||||
|
|
||||||
|
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||||
|
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||||
|
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||||
|
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||||
|
// Relation prefix is already present
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||||
|
// we can't safely auto-fix it - require explicit prefix
|
||||||
|
if strings.Contains(lowerWhere, " or ") ||
|
||||||
|
strings.Contains(where, "(") ||
|
||||||
|
strings.Contains(where, ")") {
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to add the relation prefix to simple column references
|
||||||
|
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||||
|
// Split by AND to handle multiple conditions (case-insensitive)
|
||||||
|
originalConditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If uppercase split didn't work, try lowercase
|
||||||
|
if len(originalConditions) == 1 {
|
||||||
|
originalConditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedConditions := make([]string, 0, len(originalConditions))
|
||||||
|
|
||||||
|
for _, cond := range originalConditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this condition already has a table prefix (contains a dot)
|
||||||
|
if strings.Contains(cond, ".") {
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
|
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||||
|
if IsSQLExpression(lowerCond) {
|
||||||
|
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the column name (first identifier before operator)
|
||||||
|
columnName := ExtractColumnName(cond)
|
||||||
|
if columnName == "" {
|
||||||
|
// Can't identify column name, require explicit prefix
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add relation prefix to the column name only
|
||||||
|
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||||
|
fixedConditions = append(fixedConditions, fixedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||||
|
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||||
|
return fixedWhere, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||||
|
func IsSQLExpression(cond string) bool {
|
||||||
|
// Common SQL literals and expressions
|
||||||
|
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
|
||||||
|
for _, literal := range sqlLiterals {
|
||||||
|
if cond == literal {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractColumnName extracts the column name from a WHERE condition
|
||||||
|
// For example: "status = 'active'" returns "status"
|
||||||
|
func ExtractColumnName(cond string) string {
|
||||||
|
// Common SQL operators
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||||
|
|
||||||
|
for _, op := range operators {
|
||||||
|
if idx := strings.Index(cond, op); idx > 0 {
|
||||||
|
columnName := strings.TrimSpace(cond[:idx])
|
||||||
|
// Remove quotes if present
|
||||||
|
columnName = strings.Trim(columnName, "`\"'")
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no operator found, check if it's a simple identifier (for boolean columns)
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnName := strings.Trim(parts[0], "`\"'")
|
||||||
|
// Check if it's a valid identifier (not a SQL keyword)
|
||||||
|
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
|
||||||
|
func IsSQLKeyword(word string) bool {
|
||||||
|
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
|
||||||
|
for _, kw := range keywords {
|
||||||
|
if word == kw {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -103,3 +103,18 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
func CatchPanic(location string) {
|
func CatchPanic(location string) {
|
||||||
CatchPanicCallback(location, nil)
|
CatchPanicCallback(location, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandlePanic logs a panic and returns it as an error
|
||||||
|
// This should be called with the result of recover() from a deferred function
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// defer func() {
|
||||||
|
// if r := recover(); r != nil {
|
||||||
|
// err = logger.HandlePanic("MethodName", r)
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
func HandlePanic(methodName string, r any) error {
|
||||||
|
stack := debug.Stack()
|
||||||
|
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||||
|
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1132,10 +1132,15 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||||
relationFieldName := relInfo.fieldName
|
relationFieldName := relInfo.fieldName
|
||||||
|
|
||||||
// For now, we'll preload without conditions
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
// TODO: Implement column selection and filtering for preloads
|
if len(preload.Where) > 0 {
|
||||||
// This requires a more sophisticated approach with callbacks or query builders
|
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
|
||||||
// Apply preloading
|
if err != nil {
|
||||||
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
||||||
|
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
||||||
|
}
|
||||||
|
preload.Where = fixedWhere
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Applying preload: %s", relationFieldName)
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ const (
|
|||||||
contextKeyTableName contextKey = "tableName"
|
contextKeyTableName contextKey = "tableName"
|
||||||
contextKeyModel contextKey = "model"
|
contextKeyModel contextKey = "model"
|
||||||
contextKeyModelPtr contextKey = "modelPtr"
|
contextKeyModelPtr contextKey = "modelPtr"
|
||||||
|
contextKeyOptions contextKey = "options"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WithSchema adds schema to context
|
// WithSchema adds schema to context
|
||||||
@@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} {
|
|||||||
return ctx.Value(contextKeyModelPtr)
|
return ctx.Value(contextKeyModelPtr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithOptions adds request options to context
|
||||||
|
func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyOptions, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOptions retrieves request options from context
|
||||||
|
func GetOptions(ctx context.Context) *ExtendedRequestOptions {
|
||||||
|
if v := ctx.Value(contextKeyOptions); v != nil {
|
||||||
|
if opts, ok := v.(ExtendedRequestOptions); ok {
|
||||||
|
return &opts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// WithRequestData adds all request-scoped data to context at once
|
// WithRequestData adds all request-scoped data to context at once
|
||||||
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context {
|
||||||
ctx = WithSchema(ctx, schema)
|
ctx = WithSchema(ctx, schema)
|
||||||
ctx = WithEntity(ctx, entity)
|
ctx = WithEntity(ctx, entity)
|
||||||
ctx = WithTableName(ctx, tableName)
|
ctx = WithTableName(ctx, tableName)
|
||||||
ctx = WithModel(ctx, model)
|
ctx = WithModel(ctx, model)
|
||||||
ctx = WithModelPtr(ctx, modelPtr)
|
ctx = WithModelPtr(ctx, modelPtr)
|
||||||
|
ctx = WithOptions(ctx, options)
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,9 +65,6 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
entity := params["entity"]
|
entity := params["entity"]
|
||||||
id := params["id"]
|
id := params["id"]
|
||||||
|
|
||||||
// Parse options from headers (now returns ExtendedRequestOptions)
|
|
||||||
options := h.parseOptionsFromHeaders(r)
|
|
||||||
|
|
||||||
// Determine operation based on HTTP method
|
// Determine operation based on HTTP method
|
||||||
method := r.Method()
|
method := r.Method()
|
||||||
|
|
||||||
@@ -104,13 +101,16 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
|
||||||
// Add request-scoped data to context
|
// Parse options from headers - this now includes relation name resolution
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
options := h.parseOptionsFromHeaders(r, model)
|
||||||
|
|
||||||
// Validate and filter columns in options (log warnings for invalid columns)
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
validator := common.NewColumnValidator(model)
|
validator := common.NewColumnValidator(model)
|
||||||
options = filterExtendedOptions(validator, options)
|
options = filterExtendedOptions(validator, options)
|
||||||
|
|
||||||
|
// Add request-scoped data to context (including options)
|
||||||
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@@ -344,6 +344,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
for idx := range options.Preload {
|
for idx := range options.Preload {
|
||||||
preload := options.Preload[idx]
|
preload := options.Preload[idx]
|
||||||
logger.Debug("Applying preload: %s", preload.Relation)
|
logger.Debug("Applying preload: %s", preload.Relation)
|
||||||
|
|
||||||
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
|
if len(preload.Where) > 0 {
|
||||||
|
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
||||||
|
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preload.Where = fixedWhere
|
||||||
|
}
|
||||||
|
|
||||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
if len(preload.OmitColumns) > 0 {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetModelColumns(model)
|
allCols := reflection.GetModelColumns(model)
|
||||||
|
|||||||
@@ -99,7 +99,8 @@ func DecodeParam(pStr string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseOptionsFromHeaders parses all request options from HTTP headers
|
// parseOptionsFromHeaders parses all request options from HTTP headers
|
||||||
func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions {
|
// If model is provided, it will resolve table names to field names in preload/expand options
|
||||||
|
func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) ExtendedRequestOptions {
|
||||||
options := ExtendedRequestOptions{
|
options := ExtendedRequestOptions{
|
||||||
RequestOptions: common.RequestOptions{
|
RequestOptions: common.RequestOptions{
|
||||||
Filters: make([]common.FilterOption, 0),
|
Filters: make([]common.FilterOption, 0),
|
||||||
@@ -225,6 +226,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve relation names (convert table names to field names) if model is provided
|
||||||
|
if model != nil {
|
||||||
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
|
}
|
||||||
|
|
||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -655,6 +661,192 @@ func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedReques
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveRelationNamesInOptions resolves all table names to field names in preload options
|
||||||
|
// This is called internally by parseOptionsFromHeaders when a model is provided
|
||||||
|
func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, model interface{}) {
|
||||||
|
if options == nil || model == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve relation names in all preload options
|
||||||
|
for i := range options.Preload {
|
||||||
|
preload := &options.Preload[i]
|
||||||
|
|
||||||
|
// Split the relation path (e.g., "parent.child.grandchild")
|
||||||
|
parts := strings.Split(preload.Relation, ".")
|
||||||
|
resolvedParts := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
// Resolve each part of the path
|
||||||
|
currentModel := model
|
||||||
|
for _, part := range parts {
|
||||||
|
resolvedPart := h.resolveRelationName(currentModel, part)
|
||||||
|
resolvedParts = append(resolvedParts, resolvedPart)
|
||||||
|
|
||||||
|
// Try to get the model type for the next level
|
||||||
|
// This allows nested resolution
|
||||||
|
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
|
||||||
|
currentModel = nextModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the relation path with resolved names
|
||||||
|
resolvedPath := strings.Join(resolvedParts, ".")
|
||||||
|
if resolvedPath != preload.Relation {
|
||||||
|
logger.Debug("Resolved relation path '%s' -> '%s'", preload.Relation, resolvedPath)
|
||||||
|
preload.Relation = resolvedPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve relation names in expand options
|
||||||
|
for i := range options.Expand {
|
||||||
|
expand := &options.Expand[i]
|
||||||
|
resolved := h.resolveRelationName(model, expand.Relation)
|
||||||
|
if resolved != expand.Relation {
|
||||||
|
logger.Debug("Resolved expand relation '%s' -> '%s'", expand.Relation, resolved)
|
||||||
|
expand.Relation = resolved
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRelationModel gets the model type for a relation field
|
||||||
|
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
|
||||||
|
if model == nil || fieldName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field
|
||||||
|
field, found := modelType.FieldByName(fieldName)
|
||||||
|
if !found {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the target type
|
||||||
|
targetType := field.Type
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType.Kind() == reflect.Slice {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a zero value of the target type
|
||||||
|
return reflect.New(targetType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRelationName resolves a relation name or table name to the actual field name in the model
|
||||||
|
// If the input is already a field name, it returns it as-is
|
||||||
|
// If the input is a table name, it looks up the corresponding relation field
|
||||||
|
func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) string {
|
||||||
|
if model == nil || nameOrTable == "" {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dereference pointer if needed
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check again after dereferencing
|
||||||
|
if modelType == nil {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a struct
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, check if the input matches a field name directly
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if field.Name == nameOrTable {
|
||||||
|
// It's already a field name
|
||||||
|
logger.Debug("Input '%s' is a field name", nameOrTable)
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not found as a field name, try to look it up as a table name
|
||||||
|
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
fieldType := field.Type
|
||||||
|
|
||||||
|
// Check if it's a slice or pointer to a struct
|
||||||
|
var targetType reflect.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice {
|
||||||
|
targetType = fieldType.Elem()
|
||||||
|
} else if fieldType.Kind() == reflect.Ptr {
|
||||||
|
targetType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType != nil {
|
||||||
|
// Dereference pointer if the slice contains pointers
|
||||||
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a struct type
|
||||||
|
if targetType.Kind() == reflect.Struct {
|
||||||
|
// Get the type name and normalize it
|
||||||
|
typeName := targetType.Name()
|
||||||
|
|
||||||
|
// Extract the table name from type name
|
||||||
|
// Patterns: ModelCoreMastertaskitem -> mastertaskitem
|
||||||
|
// ModelMastertaskitem -> mastertaskitem
|
||||||
|
normalizedTypeName := strings.ToLower(typeName)
|
||||||
|
|
||||||
|
// Remove common prefixes like "model", "modelcore", etc.
|
||||||
|
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||||
|
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||||
|
|
||||||
|
// Compare normalized names
|
||||||
|
if normalizedTypeName == normalizedInput {
|
||||||
|
logger.Debug("Resolved table name '%s' to field '%s' (type: %s)", nameOrTable, field.Name, typeName)
|
||||||
|
return field.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no match found, return the original input
|
||||||
|
logger.Debug("No field found for '%s', using as-is", nameOrTable)
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||||
// and recursively processes its children
|
// and recursively processes its children
|
||||||
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||||
@@ -662,7 +854,8 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the relation path
|
// Store the table name as-is for now - it will be resolved to field name later
|
||||||
|
// when we have the model instance available
|
||||||
relationPath := xfile.TableName
|
relationPath := xfile.TableName
|
||||||
if basePath != "" {
|
if basePath != "" {
|
||||||
relationPath = basePath + "." + xfile.TableName
|
relationPath = basePath + "." + xfile.TableName
|
||||||
|
|||||||
Reference in New Issue
Block a user