mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-12 02:43:53 +00:00
feat(db): add reconnect logic for database adapters
* Implement reconnect functionality in GormAdapter and other database adapters. * Introduce a DBFactory to handle reconnections. * Update health check logic to skip reconnects for transient failures. * Add tests for reconnect behavior in DatabaseAuthenticator.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -15,7 +16,9 @@ import (
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
type GormAdapter struct {
|
||||
dbMu sync.RWMutex
|
||||
db *gorm.DB
|
||||
dbFactory func() (*gorm.DB, error)
|
||||
driverName string
|
||||
}
|
||||
|
||||
@@ -27,10 +30,72 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
|
||||
g.dbFactory = factory
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormAdapter) getDB() *gorm.DB {
|
||||
g.dbMu.RLock()
|
||||
defer g.dbMu.RUnlock()
|
||||
return g.db
|
||||
}
|
||||
|
||||
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
|
||||
if g.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
|
||||
freshDB, err := g.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.dbMu.Lock()
|
||||
previous := g.db
|
||||
g.db = freshDB
|
||||
g.driverName = normalizeGormDriverName(freshDB)
|
||||
g.dbMu.Unlock()
|
||||
|
||||
if previous != nil {
|
||||
syncGormConnPool(previous, freshDB)
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if target != nil && target != previous {
|
||||
syncGormConnPool(target, freshDB)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncGormConnPool(target, fresh *gorm.DB) {
|
||||
if target == nil || fresh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if target.Config != nil && fresh.Config != nil {
|
||||
target.Config.ConnPool = fresh.Config.ConnPool
|
||||
}
|
||||
|
||||
if target.Statement != nil {
|
||||
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
|
||||
target.Statement.ConnPool = fresh.Statement.ConnPool
|
||||
} else if fresh.Config != nil {
|
||||
target.Statement.ConnPool = fresh.Config.ConnPool
|
||||
}
|
||||
target.Statement.DB = target
|
||||
}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||
g.dbMu.Lock()
|
||||
g.db = g.db.Debug()
|
||||
g.dbMu.Unlock()
|
||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||
return g
|
||||
}
|
||||
@@ -44,19 +109,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
||||
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||
return &GormInsertQuery{db: g.db}
|
||||
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &GormUpdateQuery{db: g.db}
|
||||
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
return &GormDeleteQuery{db: g.db}
|
||||
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -65,7 +130,15 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Exec(query, args...)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -75,15 +148,32 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx := g.db.WithContext(ctx).Begin()
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Begin()
|
||||
}
|
||||
tx := run()
|
||||
if isDBClosed(tx.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
tx = run()
|
||||
}
|
||||
}
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
||||
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil
|
||||
}
|
||||
|
||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -100,24 +190,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
return g.getDB()
|
||||
}
|
||||
|
||||
func (g *GormAdapter) DriverName() string {
|
||||
if g.db.Dialector == nil {
|
||||
return normalizeGormDriverName(g.getDB())
|
||||
}
|
||||
|
||||
func normalizeGormDriverName(db *gorm.DB) string {
|
||||
if db == nil || db.Dialector == nil {
|
||||
return ""
|
||||
}
|
||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||
switch name := g.db.Name(); name {
|
||||
switch name := db.Name(); name {
|
||||
case "sqlserver":
|
||||
return "mssql"
|
||||
case "sqlite3":
|
||||
@@ -130,6 +233,7 @@ func (g *GormAdapter) DriverName() string {
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
@@ -347,6 +451,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
}
|
||||
|
||||
@@ -385,6 +490,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
@@ -444,7 +550,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
err = g.db.WithContext(ctx).Find(dest).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -464,7 +578,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -483,7 +605,15 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
}
|
||||
}()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Count(&count64).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -502,7 +632,15 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
}
|
||||
}()
|
||||
var count int64
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -516,6 +654,7 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
// GormInsertQuery implements InsertQuery for GORM
|
||||
type GormInsertQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
}
|
||||
@@ -555,14 +694,21 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
switch {
|
||||
case g.model != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
run := func() *gorm.DB {
|
||||
switch {
|
||||
case g.model != nil:
|
||||
return g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
return g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
return g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
}
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
@@ -570,6 +716,7 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
// GormUpdateQuery implements UpdateQuery for GORM
|
||||
type GormUpdateQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
updates interface{}
|
||||
}
|
||||
@@ -647,7 +794,15 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Updates(g.updates)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -661,6 +816,7 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
// GormDeleteQuery implements DeleteQuery for GORM
|
||||
type GormDeleteQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
}
|
||||
|
||||
@@ -686,7 +842,15 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Delete(g.model)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
|
||||
Reference in New Issue
Block a user