mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
Better handling of preload where conditions and a few panic changes
This commit is contained in:
parent
745564f2e7
commit
0d4909054c
@ -46,8 +46,8 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
|||||||
|
|
||||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.db.ExecContext(ctx, query, args...)
|
result, err := b.db.ExecContext(ctx, query, args...)
|
||||||
@ -56,8 +56,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
|||||||
|
|
||||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||||
@ -86,8 +86,8 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
|
|
||||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||||
@ -235,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
|
|
||||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if len(apply) == 0 {
|
if len(apply) == 0 {
|
||||||
return sq
|
return sq
|
||||||
}
|
}
|
||||||
@ -294,8 +299,8 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
|||||||
|
|
||||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if dest == nil {
|
if dest == nil {
|
||||||
@ -306,8 +311,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
|||||||
|
|
||||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if b.query.GetModel() == nil {
|
if b.query.GetModel() == nil {
|
||||||
@ -319,8 +324,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||||
count = 0
|
count = 0
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -341,8 +346,8 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
|
|
||||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -392,8 +397,8 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
|
|
||||||
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if b.values != nil && len(b.values) > 0 {
|
if b.values != nil && len(b.values) > 0 {
|
||||||
@ -478,8 +483,8 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|||||||
|
|
||||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
@ -508,8 +513,8 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
|
|||||||
|
|
||||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
|
|||||||
@ -41,8 +41,8 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
|||||||
|
|
||||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||||
@ -51,8 +51,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
|||||||
|
|
||||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||||
@ -76,8 +76,8 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
|
|
||||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
@ -273,8 +273,8 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
|
|||||||
|
|
||||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Find(dest).Error
|
return g.db.WithContext(ctx).Find(dest).Error
|
||||||
@ -282,8 +282,8 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
|||||||
|
|
||||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if g.db.Statement.Model == nil {
|
if g.db.Statement.Model == nil {
|
||||||
@ -294,8 +294,8 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormSelectQuery.Count", r)
|
||||||
count = 0
|
count = 0
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -306,8 +306,8 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
|
|
||||||
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormSelectQuery.Exists", r)
|
||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -354,8 +354,8 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
|
|
||||||
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var result *gorm.DB
|
var result *gorm.DB
|
||||||
@ -446,8 +446,8 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|||||||
|
|
||||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||||
@ -478,8 +478,8 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
|
|||||||
|
|
||||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil {
|
if r := recover(); r != nil {
|
||||||
err = panicErr
|
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Delete(g.model)
|
result := g.db.WithContext(ctx).Delete(g.model)
|
||||||
|
|||||||
@ -104,13 +104,17 @@ func CatchPanic(location string) {
|
|||||||
CatchPanicCallback(location, nil)
|
CatchPanicCallback(location, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecoverPanic recovers from panics and returns an error
|
// HandlePanic logs a panic and returns it as an error
|
||||||
// Use this in deferred functions to convert panics into errors
|
// This should be called with the result of recover() from a deferred function
|
||||||
func RecoverPanic(methodName string) error {
|
// Example usage:
|
||||||
if r := recover(); r != nil {
|
//
|
||||||
|
// defer func() {
|
||||||
|
// if r := recover(); r != nil {
|
||||||
|
// err = logger.HandlePanic("MethodName", r)
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
func HandlePanic(methodName string, r any) error {
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1105,6 +1105,69 @@ type relationshipInfo struct {
|
|||||||
relatedModel interface{}
|
relatedModel interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
|
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||||
|
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||||
|
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||||
|
if where == "" {
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the relation name is already present in the WHERE clause
|
||||||
|
lowerWhere := strings.ToLower(where)
|
||||||
|
lowerRelation := strings.ToLower(relationName)
|
||||||
|
|
||||||
|
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||||
|
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||||
|
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||||
|
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||||
|
// Relation prefix is already present
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||||
|
// we can't safely auto-fix it - require explicit prefix
|
||||||
|
if strings.Contains(lowerWhere, " or ") ||
|
||||||
|
strings.Contains(where, "(") ||
|
||||||
|
strings.Contains(where, ")") {
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to add the relation prefix to simple column references
|
||||||
|
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||||
|
// Split by AND to handle multiple conditions (case-insensitive)
|
||||||
|
originalConditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If uppercase split didn't work, try lowercase
|
||||||
|
if len(originalConditions) == 1 {
|
||||||
|
originalConditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedConditions := make([]string, 0, len(originalConditions))
|
||||||
|
|
||||||
|
for _, cond := range originalConditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this condition already has a table prefix (contains a dot)
|
||||||
|
if strings.Contains(cond, ".") {
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add relation prefix to the column name
|
||||||
|
// This prefixes the entire condition with "relationName."
|
||||||
|
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
|
||||||
|
fixedConditions = append(fixedConditions, fixedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||||
|
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||||
|
return fixedWhere, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
@ -1132,10 +1195,15 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||||
relationFieldName := relInfo.fieldName
|
relationFieldName := relInfo.fieldName
|
||||||
|
|
||||||
// For now, we'll preload without conditions
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
// TODO: Implement column selection and filtering for preloads
|
if len(preload.Where) > 0 {
|
||||||
// This requires a more sophisticated approach with callbacks or query builders
|
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, relationFieldName)
|
||||||
// Apply preloading
|
if err != nil {
|
||||||
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
||||||
|
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
||||||
|
}
|
||||||
|
preload.Where = fixedWhere
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Applying preload: %s", relationFieldName)
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
|||||||
@ -200,6 +200,69 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
|
|
||||||
// parseOptionsFromHeaders is now implemented in headers.go
|
// parseOptionsFromHeaders is now implemented in headers.go
|
||||||
|
|
||||||
|
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
|
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||||
|
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||||
|
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||||
|
if where == "" {
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the relation name is already present in the WHERE clause
|
||||||
|
lowerWhere := strings.ToLower(where)
|
||||||
|
lowerRelation := strings.ToLower(relationName)
|
||||||
|
|
||||||
|
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||||
|
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||||
|
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||||
|
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||||
|
// Relation prefix is already present
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||||
|
// we can't safely auto-fix it - require explicit prefix
|
||||||
|
if strings.Contains(lowerWhere, " or ") ||
|
||||||
|
strings.Contains(where, "(") ||
|
||||||
|
strings.Contains(where, ")") {
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to add the relation prefix to simple column references
|
||||||
|
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||||
|
// Split by AND to handle multiple conditions (case-insensitive)
|
||||||
|
originalConditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If uppercase split didn't work, try lowercase
|
||||||
|
if len(originalConditions) == 1 {
|
||||||
|
originalConditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedConditions := make([]string, 0, len(originalConditions))
|
||||||
|
|
||||||
|
for _, cond := range originalConditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this condition already has a table prefix (contains a dot)
|
||||||
|
if strings.Contains(cond, ".") {
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add relation prefix to the column name
|
||||||
|
// This prefixes the entire condition with "relationName."
|
||||||
|
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
|
||||||
|
fixedConditions = append(fixedConditions, fixedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||||
|
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||||
|
return fixedWhere, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -344,6 +407,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
for idx := range options.Preload {
|
for idx := range options.Preload {
|
||||||
preload := options.Preload[idx]
|
preload := options.Preload[idx]
|
||||||
logger.Debug("Applying preload: %s", preload.Relation)
|
logger.Debug("Applying preload: %s", preload.Relation)
|
||||||
|
|
||||||
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
|
if len(preload.Where) > 0 {
|
||||||
|
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, preload.Relation)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
||||||
|
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preload.Where = fixedWhere
|
||||||
|
}
|
||||||
|
|
||||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
if len(preload.OmitColumns) > 0 {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetModelColumns(model)
|
allCols := reflection.GetModelColumns(model)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user