fix(db): improve database connection handling and reconnection logic

* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
This commit is contained in:
Hein
2026-04-09 09:19:28 +02:00
parent a9bf08f58b
commit 79a3912f93
10 changed files with 449 additions and 91 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/uptrace/bun"
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
// This demonstrates how the abstraction works with different ORMs
type BunAdapter struct {
db *bun.DB
dbMu sync.RWMutex
dbFactory func() (*bun.DB, error)
driverName string
}
@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
return adapter
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
b.dbFactory = factory
return b
}
func (b *BunAdapter) getDB() *bun.DB {
b.dbMu.RLock()
defer b.dbMu.RUnlock()
return b.db
}
func (b *BunAdapter) reconnectDB() error {
if b.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := b.dbFactory()
if err != nil {
return err
}
b.dbMu.Lock()
b.db = newDB
b.dbMu.Unlock()
return nil
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{})
b.getDB().AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
}
@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{
query: b.db.NewSelect(),
query: b.getDB().NewSelect(),
db: b.db,
driverName: b.driverName,
}
}
func (b *BunAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.db.NewInsert()}
return &BunInsertQuery{query: b.getDB().NewInsert()}
}
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.db.NewUpdate()}
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
}
func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()}
return &BunDeleteQuery{query: b.getDB().NewDelete()}
}
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
result, err := b.db.ExecContext(ctx, query, args...)
var result sql.Result
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = run()
}
}
return &BunResult{result: result}, err
}
@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
return b.db.NewRaw(query, args...).Scan(ctx, dest)
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
}
}
return err
}
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return nil, err
}
@@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
}
}()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
return fn(adapter)
@@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
}
func (b *BunAdapter) GetUnderlyingDB() interface{} {
return b.db
return b.getDB()
}
func (b *BunAdapter) DriverName() string {