feat(db): add reconnect logic for database adapters

* Implement reconnect functionality in GormAdapter and other database adapters.
* Introduce a DBFactory to handle reconnections.
* Update health check logic to skip reconnects for transient failures.
* Add tests for reconnect behavior in DatabaseAuthenticator.
This commit is contained in:
Hein
2026-04-10 11:18:39 +02:00
parent 2afee9d238
commit 16a960d973
8 changed files with 728 additions and 57 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"gorm.io/gorm"
@@ -15,7 +16,9 @@ import (
// GormAdapter adapts GORM to work with our Database interface
type GormAdapter struct {
dbMu sync.RWMutex
db *gorm.DB
dbFactory func() (*gorm.DB, error)
driverName string
}
@@ -27,10 +30,72 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
return adapter
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
g.dbFactory = factory
return g
}
func (g *GormAdapter) getDB() *gorm.DB {
g.dbMu.RLock()
defer g.dbMu.RUnlock()
return g.db
}
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
if g.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
freshDB, err := g.dbFactory()
if err != nil {
return err
}
g.dbMu.Lock()
previous := g.db
g.db = freshDB
g.driverName = normalizeGormDriverName(freshDB)
g.dbMu.Unlock()
if previous != nil {
syncGormConnPool(previous, freshDB)
}
for _, target := range targets {
if target != nil && target != previous {
syncGormConnPool(target, freshDB)
}
}
return nil
}
func syncGormConnPool(target, fresh *gorm.DB) {
if target == nil || fresh == nil {
return
}
if target.Config != nil && fresh.Config != nil {
target.Config.ConnPool = fresh.Config.ConnPool
}
if target.Statement != nil {
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
target.Statement.ConnPool = fresh.Statement.ConnPool
} else if fresh.Config != nil {
target.Statement.ConnPool = fresh.Config.ConnPool
}
target.Statement.DB = target
}
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.dbMu.Lock()
g.db = g.db.Debug()
g.dbMu.Unlock()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
@@ -44,19 +109,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
}
func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db, driverName: g.driverName}
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB}
}
func (g *GormAdapter) NewInsert() common.InsertQuery {
return &GormInsertQuery{db: g.db}
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB}
}
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
return &GormUpdateQuery{db: g.db}
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB}
}
func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db}
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -65,7 +130,15 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Exec(query, args...)
}
result := run()
if isDBClosed(result.Error) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
result = run()
}
}
return &GormResult{result: result}, result.Error
}
@@ -75,15 +148,32 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
run := func() error {
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
}
err = run()
if isDBClosed(err) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
err = run()
}
}
return err
}
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx := g.db.WithContext(ctx).Begin()
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Begin()
}
tx := run()
if isDBClosed(tx.Error) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
tx = run()
}
}
if tx.Error != nil {
return nil, tx.Error
}
return &GormAdapter{db: tx, driverName: g.driverName}, nil
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil
}
func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -100,24 +190,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, driverName: g.driverName}
return fn(adapter)
})
run := func() error {
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}
return fn(adapter)
})
}
err = run()
if isDBClosed(err) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
err = run()
}
}
return err
}
func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db
return g.getDB()
}
func (g *GormAdapter) DriverName() string {
if g.db.Dialector == nil {
return normalizeGormDriverName(g.getDB())
}
func normalizeGormDriverName(db *gorm.DB) string {
if db == nil || db.Dialector == nil {
return ""
}
// Normalize GORM's dialector name to match the project's canonical vocabulary.
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
switch name := g.db.Name(); name {
switch name := db.Name(); name {
case "sqlserver":
return "mssql"
case "sqlite3":
@@ -130,6 +233,7 @@ func (g *GormAdapter) DriverName() string {
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
@@ -347,6 +451,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
wrapper := &GormSelectQuery{
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
}
@@ -385,6 +490,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
wrapper := &GormSelectQuery{
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
@@ -444,7 +550,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
err = g.db.WithContext(ctx).Find(dest).Error
run := func() error {
return g.db.WithContext(ctx).Find(dest).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -464,7 +578,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
run := func() error {
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -483,7 +605,15 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
}
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
run := func() error {
return g.db.WithContext(ctx).Count(&count64).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -502,7 +632,15 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
}
}()
var count int64
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
run := func() error {
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -516,6 +654,7 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
// GormInsertQuery implements InsertQuery for GORM
type GormInsertQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
values map[string]interface{}
}
@@ -555,14 +694,21 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
switch {
case g.model != nil:
result = g.db.WithContext(ctx).Create(g.model)
case g.values != nil:
result = g.db.WithContext(ctx).Create(g.values)
default:
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
run := func() *gorm.DB {
switch {
case g.model != nil:
return g.db.WithContext(ctx).Create(g.model)
case g.values != nil:
return g.db.WithContext(ctx).Create(g.values)
default:
return g.db.WithContext(ctx).Create(map[string]interface{}{})
}
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
return &GormResult{result: result}, result.Error
}
@@ -570,6 +716,7 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
// GormUpdateQuery implements UpdateQuery for GORM
type GormUpdateQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
updates interface{}
}
@@ -647,7 +794,15 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
run := func() *gorm.DB {
return g.db.WithContext(ctx).Updates(g.updates)
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -661,6 +816,7 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
// GormDeleteQuery implements DeleteQuery for GORM
type GormDeleteQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
}
@@ -686,7 +842,15 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
run := func() *gorm.DB {
return g.db.WithContext(ctx).Delete(g.model)
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {