Better handling of preload where conditions and a few panic changes

This commit is contained in:
Hein 2025-11-20 16:50:26 +02:00
parent 745564f2e7
commit 0d4909054c
5 changed files with 206 additions and 53 deletions

View File

@ -46,8 +46,8 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunAdapter.Exec", r)
} }
}() }()
result, err := b.db.ExecContext(ctx, query, args...) result, err := b.db.ExecContext(ctx, query, args...)
@ -56,8 +56,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunAdapter.Query", r)
} }
}() }()
return b.db.NewRaw(query, args...).Scan(ctx, dest) return b.db.NewRaw(query, args...).Scan(ctx, dest)
@ -86,8 +86,8 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
} }
}() }()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
@ -235,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
}
}()
if len(apply) == 0 { if len(apply) == 0 {
return sq return sq
} }
@ -294,8 +299,8 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.Scan", r)
} }
}() }()
if dest == nil { if dest == nil {
@ -306,8 +311,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
} }
}() }()
if b.query.GetModel() == nil { if b.query.GetModel() == nil {
@ -319,8 +324,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0 count = 0
} }
}() }()
@ -341,8 +346,8 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) { func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false exists = false
} }
}() }()
@ -392,8 +397,8 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunInsertQuery.Exec", r)
} }
}() }()
if b.values != nil && len(b.values) > 0 { if b.values != nil && len(b.values) > 0 {
@ -478,8 +483,8 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunUpdateQuery.Exec", r)
} }
}() }()
result, err := b.query.Exec(ctx) result, err := b.query.Exec(ctx)
@ -508,8 +513,8 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunDeleteQuery.Exec", r)
} }
}() }()
result, err := b.query.Exec(ctx) result, err := b.query.Exec(ctx)

View File

@ -41,8 +41,8 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormAdapter.Exec", r)
} }
}() }()
result := g.db.WithContext(ctx).Exec(query, args...) result := g.db.WithContext(ctx).Exec(query, args...)
@ -51,8 +51,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormAdapter.Query", r)
} }
}() }()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
@ -76,8 +76,8 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
} }
}() }()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
@ -273,8 +273,8 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.Scan", r)
} }
}() }()
return g.db.WithContext(ctx).Find(dest).Error return g.db.WithContext(ctx).Find(dest).Error
@ -282,8 +282,8 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
} }
}() }()
if g.db.Statement.Model == nil { if g.db.Statement.Model == nil {
@ -294,8 +294,8 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0 count = 0
} }
}() }()
@ -306,8 +306,8 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) { func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false exists = false
} }
}() }()
@ -354,8 +354,8 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormInsertQuery.Exec", r)
} }
}() }()
var result *gorm.DB var result *gorm.DB
@ -446,8 +446,8 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormUpdateQuery.Exec", r)
} }
}() }()
result := g.db.WithContext(ctx).Updates(g.updates) result := g.db.WithContext(ctx).Updates(g.updates)
@ -478,8 +478,8 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormDeleteQuery.Exec", r)
} }
}() }()
result := g.db.WithContext(ctx).Delete(g.model) result := g.db.WithContext(ctx).Delete(g.model)

View File

@ -104,13 +104,17 @@ func CatchPanic(location string) {
CatchPanicCallback(location, nil) CatchPanicCallback(location, nil)
} }
// RecoverPanic recovers from panics and returns an error // HandlePanic logs a panic and returns it as an error
// Use this in deferred functions to convert panics into errors // This should be called with the result of recover() from a deferred function
func RecoverPanic(methodName string) error { // Example usage:
if r := recover(); r != nil { //
// defer func() {
// if r := recover(); r != nil {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack() stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
return fmt.Errorf("panic in %s: %v", methodName, r) return fmt.Errorf("panic in %s: %v", methodName, r)
} }
return nil
}

View File

@ -1105,6 +1105,69 @@ type relationshipInfo struct {
relatedModel interface{} relatedModel interface{}
} }
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
// the relation prefix (alias). If not present, it attempts to add it to column references.
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
// Check if the relation name is already present in the WHERE clause
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
if strings.Contains(lowerWhere, lowerRelation+".") ||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
// Relation prefix is already present
return where, nil
}
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
// we can't safely auto-fix it - require explicit prefix
if strings.Contains(lowerWhere, " or ") ||
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Add relation prefix to the column name
// This prefixes the entire condition with "relationName."
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
@ -1132,10 +1195,15 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// ORMs like GORM and Bun expect the struct field name, not the JSON name // ORMs like GORM and Bun expect the struct field name, not the JSON name
relationFieldName := relInfo.fieldName relationFieldName := relInfo.fieldName
// For now, we'll preload without conditions // Validate and fix WHERE clause to ensure it contains the relation prefix
// TODO: Implement column selection and filtering for preloads if len(preload.Where) > 0 {
// This requires a more sophisticated approach with callbacks or query builders fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, relationFieldName)
// Apply preloading if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
}
preload.Where = fixedWhere
}
logger.Debug("Applying preload: %s", relationFieldName) logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery { query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {

View File

@ -200,6 +200,69 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
// parseOptionsFromHeaders is now implemented in headers.go // parseOptionsFromHeaders is now implemented in headers.go
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
// the relation prefix (alias). If not present, it attempts to add it to column references.
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
// Check if the relation name is already present in the WHERE clause
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
if strings.Contains(lowerWhere, lowerRelation+".") ||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
// Relation prefix is already present
return where, nil
}
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
// we can't safely auto-fix it - require explicit prefix
if strings.Contains(lowerWhere, " or ") ||
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Add relation prefix to the column name
// This prefixes the entire condition with "relationName."
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
}
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) { func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
// Capture panics and return error response // Capture panics and return error response
defer func() { defer func() {
@ -344,6 +407,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
for idx := range options.Preload { for idx := range options.Preload {
preload := options.Preload[idx] preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation) logger.Debug("Applying preload: %s", preload.Relation)
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, preload.Relation)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
return
}
preload.Where = fixedWhere
}
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery { query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 { if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model) allCols := reflection.GetModelColumns(model)