diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index b16c2ba..ec44420 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -9,6 +9,7 @@ import ( "github.com/uptrace/bun" "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/reflection" ) @@ -43,12 +44,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery { return &BunDeleteQuery{query: b.db.NewDelete()} } -func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { +func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil { + err = panicErr + } + }() result, err := b.db.ExecContext(ctx, query, args...) return &BunResult{result: result}, err } -func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil { + err = panicErr + } + }() return b.db.NewRaw(query, args...).Scan(ctx, dest) } @@ -73,7 +84,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error { return nil } -func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { +func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil { + err = panicErr + } + }() return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { // Create adapter with transaction adapter := &BunTxAdapter{tx: tx} @@ -276,15 +292,38 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec return b } -func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error { +func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil { + err = panicErr + } + }() + if dest == nil { + return fmt.Errorf("destination cannot be nil") + } return b.query.Scan(ctx, dest) } -func (b *BunSelectQuery) ScanModel(ctx context.Context) error { +func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil { + err = panicErr + } + }() + if b.query.GetModel() == nil { + return fmt.Errorf("model is nil") + } + return b.query.Scan(ctx) } -func (b *BunSelectQuery) Count(ctx context.Context) (int, error) { +func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil { + err = panicErr + count = 0 + } + }() // If Model() was set, use bun's native Count() which works properly if b.hasModel { count, err := b.query.Count(ctx) @@ -293,15 +332,20 @@ func (b *BunSelectQuery) Count(ctx context.Context) (int, error) { // Otherwise, wrap as subquery to avoid "Model(nil)" error // This is needed when only Table() is set without a model - var count int - err := b.db.NewSelect(). + err = b.db.NewSelect(). TableExpr("(?) AS subquery", b.query). ColumnExpr("COUNT(*)"). Scan(ctx, &count) return count, err } -func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) { +func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil { + err = panicErr + exists = false + } + }() return b.query.Exists(ctx) } @@ -320,7 +364,6 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery { func (b *BunInsertQuery) Table(table string) common.InsertQuery { if b.hasModel { - // If model is set, do not override table name return b } b.query = b.query.Table(table) @@ -347,7 +390,12 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery { return b } -func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) { +func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil { + err = panicErr + } + }() if b.values != nil && len(b.values) > 0 { if !b.hasModel { // If no model was set, use the values map as the model @@ -428,7 +476,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery { return b } -func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) { +func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil { + err = panicErr + } + }() result, err := b.query.Exec(ctx) return &BunResult{result: result}, err } @@ -453,7 +506,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ return b } -func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) { +func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil { + err = panicErr + } + }() result, err := b.query.Exec(ctx) return &BunResult{result: result}, err } diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 666d9cb..26589b9 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -8,6 +8,7 @@ import ( "gorm.io/gorm" "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/reflection" ) @@ -38,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery { return &GormDeleteQuery{db: g.db} } -func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { +func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil { + err = panicErr + } + }() result := g.db.WithContext(ctx).Exec(query, args...) return &GormResult{result: result}, result.Error } -func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil { + err = panicErr + } + }() return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error } @@ -63,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error { return g.db.WithContext(ctx).Rollback().Error } -func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { +func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil { + err = panicErr + } + }() return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { adapter := &GormAdapter{db: tx} return fn(adapter) @@ -255,26 +271,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele return g } -func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error { +func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil { + err = panicErr + } + }() return g.db.WithContext(ctx).Find(dest).Error } -func (g *GormSelectQuery) ScanModel(ctx context.Context) error { +func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil { + err = panicErr + } + }() if g.db.Statement.Model == nil { return fmt.Errorf("ScanModel requires Model() to be set before scanning") } return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error } -func (g *GormSelectQuery) Count(ctx context.Context) (int, error) { - var count int64 - err := g.db.WithContext(ctx).Count(&count).Error - return int(count), err +func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil { + err = panicErr + count = 0 + } + }() + var count64 int64 + err = g.db.WithContext(ctx).Count(&count64).Error + return int(count64), err } -func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) { +func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil { + err = panicErr + exists = false + } + }() var count int64 - err := g.db.WithContext(ctx).Limit(1).Count(&count).Error + err = g.db.WithContext(ctx).Limit(1).Count(&count).Error return count > 0, err } @@ -314,7 +352,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery { return g } -func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) { +func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil { + err = panicErr + } + }() var result *gorm.DB switch { case g.model != nil: @@ -401,7 +444,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery { return g } -func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) { +func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil { + err = panicErr + } + }() result := g.db.WithContext(ctx).Updates(g.updates) return &GormResult{result: result}, result.Error } @@ -428,7 +476,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete return g } -func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) { +func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil { + err = panicErr + } + }() result := g.db.WithContext(ctx).Delete(g.model) return &GormResult{result: result}, result.Error } diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index c4767ea..03238b6 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -103,3 +103,14 @@ func CatchPanicCallback(location string, cb func(err any)) { func CatchPanic(location string) { CatchPanicCallback(location, nil) } + +// RecoverPanic recovers from panics and returns an error +// Use this in deferred functions to convert panics into errors +func RecoverPanic(methodName string) error { + if r := recover(); r != nil { + stack := debug.Stack() + Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) + return fmt.Errorf("panic in %s: %v", methodName, r) + } + return nil +} diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index c395f98..7438673 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -715,11 +715,15 @@ func (h *Handler) getRelationModel(model interface{}, fieldName string) interfac } modelType := reflect.TypeOf(model) + if modelType == nil { + return nil + } + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } - if modelType.Kind() != reflect.Struct { + if modelType == nil || modelType.Kind() != reflect.Struct { return nil } @@ -731,11 +735,21 @@ func (h *Handler) getRelationModel(model interface{}, fieldName string) interfac // Get the target type targetType := field.Type + if targetType == nil { + return nil + } + if targetType.Kind() == reflect.Slice { targetType = targetType.Elem() + if targetType == nil { + return nil + } } if targetType.Kind() == reflect.Ptr { targetType = targetType.Elem() + if targetType == nil { + return nil + } } if targetType.Kind() != reflect.Struct { @@ -755,11 +769,20 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str } modelType := reflect.TypeOf(model) + if modelType == nil { + return nameOrTable + } + // Dereference pointer if needed if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } + // Check again after dereferencing + if modelType == nil { + return nameOrTable + } + // Ensure it's a struct if modelType.Kind() != reflect.Struct { return nameOrTable