diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index ec44420..c0db6ab 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -46,8 +46,8 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery { func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunAdapter.Exec", r) } }() result, err := b.db.ExecContext(ctx, query, args...) @@ -56,8 +56,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{} func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { defer func() { - if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunAdapter.Query", r) } }() return b.db.NewRaw(query, args...).Scan(ctx, dest) @@ -86,8 +86,8 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error { func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { defer func() { - if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunAdapter.RunInTransaction", r) } }() return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { @@ -235,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { + defer func() { + if r := recover(); r != nil { + logger.HandlePanic("BunSelectQuery.PreloadRelation", r) + } + }() if len(apply) == 0 { return sq } @@ -294,8 +299,8 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { defer func() { - if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunSelectQuery.Scan", r) } }() if dest == nil { @@ -306,8 +311,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { defer func() { - if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunSelectQuery.ScanModel", r) } }() if b.query.GetModel() == nil { @@ -319,8 +324,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { defer func() { - if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunSelectQuery.Count", r) count = 0 } }() @@ -341,8 +346,8 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) { defer func() { - if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunSelectQuery.Exists", r) exists = false } }() @@ -392,8 +397,8 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery { func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunInsertQuery.Exec", r) } }() if b.values != nil && len(b.values) > 0 { @@ -478,8 +483,8 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery { func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunUpdateQuery.Exec", r) } }() result, err := b.query.Exec(ctx) @@ -508,8 +513,8 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("BunDeleteQuery.Exec", r) } }() result, err := b.query.Exec(ctx) diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 26589b9..5311938 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -41,8 +41,8 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery { func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormAdapter.Exec", r) } }() result := g.db.WithContext(ctx).Exec(query, args...) @@ -51,8 +51,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { defer func() { - if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormAdapter.Query", r) } }() return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error @@ -76,8 +76,8 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error { func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { defer func() { - if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormAdapter.RunInTransaction", r) } }() return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -273,8 +273,8 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { defer func() { - if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormSelectQuery.Scan", r) } }() return g.db.WithContext(ctx).Find(dest).Error @@ -282,8 +282,8 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { defer func() { - if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormSelectQuery.ScanModel", r) } }() if g.db.Statement.Model == nil { @@ -294,8 +294,8 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { defer func() { - if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormSelectQuery.Count", r) count = 0 } }() @@ -306,8 +306,8 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) { defer func() { - if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormSelectQuery.Exists", r) exists = false } }() @@ -354,8 +354,8 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery { func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormInsertQuery.Exec", r) } }() var result *gorm.DB @@ -446,8 +446,8 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery { func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormUpdateQuery.Exec", r) } }() result := g.db.WithContext(ctx).Updates(g.updates) @@ -478,8 +478,8 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { defer func() { - if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil { - err = panicErr + if r := recover(); r != nil { + err = logger.HandlePanic("GormDeleteQuery.Exec", r) } }() result := g.db.WithContext(ctx).Delete(g.model) diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 03238b6..6514660 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -104,13 +104,17 @@ func CatchPanic(location string) { CatchPanicCallback(location, nil) } -// RecoverPanic recovers from panics and returns an error -// Use this in deferred functions to convert panics into errors -func RecoverPanic(methodName string) error { - if r := recover(); r != nil { - stack := debug.Stack() - Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) - return fmt.Errorf("panic in %s: %v", methodName, r) - } - return nil +// HandlePanic logs a panic and returns it as an error +// This should be called with the result of recover() from a deferred function +// Example usage: +// +// defer func() { +// if r := recover(); r != nil { +// err = logger.HandlePanic("MethodName", r) +// } +// }() +func HandlePanic(methodName string, r any) error { + stack := debug.Stack() + Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) + return fmt.Errorf("panic in %s: %v", methodName, r) } diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index cac81df..0e532cb 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1105,6 +1105,69 @@ type relationshipInfo struct { relatedModel interface{} } +// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains +// the relation prefix (alias). If not present, it attempts to add it to column references. +// Returns the fixed WHERE clause and an error if it cannot be safely fixed. +func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) { + if where == "" { + return where, nil + } + + // Check if the relation name is already present in the WHERE clause + lowerWhere := strings.ToLower(where) + lowerRelation := strings.ToLower(relationName) + + // Check for patterns like "relation.", "relation ", or just "relation" followed by a dot + if strings.Contains(lowerWhere, lowerRelation+".") || + strings.Contains(lowerWhere, "`"+lowerRelation+"`.") || + strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { + // Relation prefix is already present + return where, nil + } + + // If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), + // we can't safely auto-fix it - require explicit prefix + if strings.Contains(lowerWhere, " or ") || + strings.Contains(where, "(") || + strings.Contains(where, ")") { + return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName) + } + + // Try to add the relation prefix to simple column references + // This handles basic cases like "column = value" or "column = value AND other_column = value" + // Split by AND to handle multiple conditions (case-insensitive) + originalConditions := strings.Split(where, " AND ") + + // If uppercase split didn't work, try lowercase + if len(originalConditions) == 1 { + originalConditions = strings.Split(where, " and ") + } + + fixedConditions := make([]string, 0, len(originalConditions)) + + for _, cond := range originalConditions { + cond = strings.TrimSpace(cond) + if cond == "" { + continue + } + + // Check if this condition already has a table prefix (contains a dot) + if strings.Contains(cond, ".") { + fixedConditions = append(fixedConditions, cond) + continue + } + + // Add relation prefix to the column name + // This prefixes the entire condition with "relationName." + fixedCond := fmt.Sprintf("%s.%s", relationName, cond) + fixedConditions = append(fixedConditions, fixedCond) + } + + fixedWhere := strings.Join(fixedConditions, " AND ") + logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere) + return fixedWhere, nil +} + func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { modelType := reflect.TypeOf(model) @@ -1132,10 +1195,15 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // ORMs like GORM and Bun expect the struct field name, not the JSON name relationFieldName := relInfo.fieldName - // For now, we'll preload without conditions - // TODO: Implement column selection and filtering for preloads - // This requires a more sophisticated approach with callbacks or query builders - // Apply preloading + // Validate and fix WHERE clause to ensure it contains the relation prefix + if len(preload.Where) > 0 { + fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, relationFieldName) + if err != nil { + logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err) + panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err)) + } + preload.Where = fixedWhere + } logger.Debug("Applying preload: %s", relationFieldName) query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 28912e6..29b201c 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -200,6 +200,69 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma // parseOptionsFromHeaders is now implemented in headers.go +// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains +// the relation prefix (alias). If not present, it attempts to add it to column references. +// Returns the fixed WHERE clause and an error if it cannot be safely fixed. +func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) { + if where == "" { + return where, nil + } + + // Check if the relation name is already present in the WHERE clause + lowerWhere := strings.ToLower(where) + lowerRelation := strings.ToLower(relationName) + + // Check for patterns like "relation.", "relation ", or just "relation" followed by a dot + if strings.Contains(lowerWhere, lowerRelation+".") || + strings.Contains(lowerWhere, "`"+lowerRelation+"`.") || + strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { + // Relation prefix is already present + return where, nil + } + + // If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), + // we can't safely auto-fix it - require explicit prefix + if strings.Contains(lowerWhere, " or ") || + strings.Contains(where, "(") || + strings.Contains(where, ")") { + return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName) + } + + // Try to add the relation prefix to simple column references + // This handles basic cases like "column = value" or "column = value AND other_column = value" + // Split by AND to handle multiple conditions (case-insensitive) + originalConditions := strings.Split(where, " AND ") + + // If uppercase split didn't work, try lowercase + if len(originalConditions) == 1 { + originalConditions = strings.Split(where, " and ") + } + + fixedConditions := make([]string, 0, len(originalConditions)) + + for _, cond := range originalConditions { + cond = strings.TrimSpace(cond) + if cond == "" { + continue + } + + // Check if this condition already has a table prefix (contains a dot) + if strings.Contains(cond, ".") { + fixedConditions = append(fixedConditions, cond) + continue + } + + // Add relation prefix to the column name + // This prefixes the entire condition with "relationName." + fixedCond := fmt.Sprintf("%s.%s", relationName, cond) + fixedConditions = append(fixedConditions, fixedCond) + } + + fixedWhere := strings.Join(fixedConditions, " AND ") + logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere) + return fixedWhere, nil +} + func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) { // Capture panics and return error response defer func() { @@ -344,6 +407,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st for idx := range options.Preload { preload := options.Preload[idx] logger.Debug("Applying preload: %s", preload.Relation) + + // Validate and fix WHERE clause to ensure it contains the relation prefix + if len(preload.Where) > 0 { + fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, preload.Relation) + if err != nil { + logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err) + h.sendError(w, http.StatusBadRequest, "invalid_preload_where", + fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err) + return + } + preload.Where = fixedWhere + } + query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery { if len(preload.OmitColumns) > 0 { allCols := reflection.GetModelColumns(model)