fix(db): improve database connection handling and reconnection logic

* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
This commit is contained in:
Hein
2026-04-09 09:19:28 +02:00
parent a9bf08f58b
commit 79a3912f93
10 changed files with 449 additions and 91 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -17,6 +18,8 @@ import (
// This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct {
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string
}
@@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
return &PgSQLAdapter{db: db, driverName: name}
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
p.dbFactory = factory
return p
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *PgSQLAdapter) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// EnableQueryDebug enables query debugging for development
func (p *PgSQLAdapter) EnableQueryDebug() {
logger.Info("PgSQL query debug mode - logging enabled via logger")
@@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
db: p.db,
db: p.getDB(),
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
@@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
db: p.db,
db: p.getDB(),
driverName: p.driverName,
values: make(map[string]interface{}),
}
@@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
db: p.db,
db: p.getDB(),
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
db: p.db,
db: p.getDB(),
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
}
}()
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
result, err := p.db.ExecContext(ctx, query, args...)
var result sql.Result
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
logger.Error("PgSQL Exec failed: %v", err)
return nil, err
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
}
}()
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
rows, err := p.db.QueryContext(ctx, query, args...)
var rows *sql.Rows
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
logger.Error("PgSQL Query failed: %v", err)
return err
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
}
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
}
}()
tx, err := p.db.BeginTx(ctx, nil)
tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil {
return err
}