From 16a960d97306a1cf9eeb133a4412b79c0542ce9e Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 10 Apr 2026 11:18:39 +0200 Subject: [PATCH] feat(db): add reconnect logic for database adapters * Implement reconnect functionality in GormAdapter and other database adapters. * Introduce a DBFactory to handle reconnections. * Update health check logic to skip reconnects for transient failures. * Add tests for reconnect behavior in DatabaseAuthenticator. --- .gitignore | 1 + pkg/common/adapters/database/gorm.go | 222 +++++++++++++++++++++++---- pkg/dbmanager/connection.go | 46 +++++- pkg/dbmanager/factory_test.go | 159 +++++++++++++++++++ pkg/dbmanager/manager.go | 23 ++- pkg/dbmanager/manager_test.go | 64 ++++++++ pkg/security/providers.go | 64 +++++--- pkg/security/providers_test.go | 206 +++++++++++++++++++++++++ 8 files changed, 728 insertions(+), 57 deletions(-) diff --git a/.gitignore b/.gitignore index 8bd5cb4..c1a1ec1 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ test.db tests/data/ node_modules/ resolvespec-js/dist/ +.codex diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index e7e6724..0bbedd2 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "sync" "gorm.io/gorm" @@ -15,7 +16,9 @@ import ( // GormAdapter adapts GORM to work with our Database interface type GormAdapter struct { + dbMu sync.RWMutex db *gorm.DB + dbFactory func() (*gorm.DB, error) driverName string } @@ -27,10 +30,72 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter { return adapter } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. +func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter { + g.dbFactory = factory + return g +} + +func (g *GormAdapter) getDB() *gorm.DB { + g.dbMu.RLock() + defer g.dbMu.RUnlock() + return g.db +} + +func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error { + if g.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + + freshDB, err := g.dbFactory() + if err != nil { + return err + } + + g.dbMu.Lock() + previous := g.db + g.db = freshDB + g.driverName = normalizeGormDriverName(freshDB) + g.dbMu.Unlock() + + if previous != nil { + syncGormConnPool(previous, freshDB) + } + + for _, target := range targets { + if target != nil && target != previous { + syncGormConnPool(target, freshDB) + } + } + + return nil +} + +func syncGormConnPool(target, fresh *gorm.DB) { + if target == nil || fresh == nil { + return + } + + if target.Config != nil && fresh.Config != nil { + target.Config.ConnPool = fresh.Config.ConnPool + } + + if target.Statement != nil { + if fresh.Statement != nil && fresh.Statement.ConnPool != nil { + target.Statement.ConnPool = fresh.Statement.ConnPool + } else if fresh.Config != nil { + target.Statement.ConnPool = fresh.Config.ConnPool + } + target.Statement.DB = target + } +} + // EnableQueryDebug enables query debugging which logs all SQL queries including preloads // This is useful for debugging preload queries that may be failing func (g *GormAdapter) EnableQueryDebug() *GormAdapter { + g.dbMu.Lock() g.db = g.db.Debug() + g.dbMu.Unlock() logger.Info("GORM query debug mode enabled - all SQL queries will be logged") return g } @@ -44,19 +109,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter { } func (g *GormAdapter) NewSelect() common.SelectQuery { - return &GormSelectQuery{db: g.db, driverName: g.driverName} + return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB} } func (g *GormAdapter) NewInsert() common.InsertQuery { - return &GormInsertQuery{db: g.db} + return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB} } func (g *GormAdapter) NewUpdate() common.UpdateQuery { - return &GormUpdateQuery{db: g.db} + return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB} } func (g *GormAdapter) NewDelete() common.DeleteQuery { - return &GormDeleteQuery{db: g.db} + return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB} } func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { @@ -65,7 +130,15 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{ err = logger.HandlePanic("GormAdapter.Exec", r) } }() - result := g.db.WithContext(ctx).Exec(query, args...) + run := func() *gorm.DB { + return g.getDB().WithContext(ctx).Exec(query, args...) + } + result := run() + if isDBClosed(result.Error) { + if reconnErr := g.reconnectDB(); reconnErr == nil { + result = run() + } + } return &GormResult{result: result}, result.Error } @@ -75,15 +148,32 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, err = logger.HandlePanic("GormAdapter.Query", r) } }() - return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error + run := func() error { + return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error + } + err = run() + if isDBClosed(err) { + if reconnErr := g.reconnectDB(); reconnErr == nil { + err = run() + } + } + return err } func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) { - tx := g.db.WithContext(ctx).Begin() + run := func() *gorm.DB { + return g.getDB().WithContext(ctx).Begin() + } + tx := run() + if isDBClosed(tx.Error) { + if reconnErr := g.reconnectDB(); reconnErr == nil { + tx = run() + } + } if tx.Error != nil { return nil, tx.Error } - return &GormAdapter{db: tx, driverName: g.driverName}, nil + return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil } func (g *GormAdapter) CommitTx(ctx context.Context) error { @@ -100,24 +190,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab err = logger.HandlePanic("GormAdapter.RunInTransaction", r) } }() - return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - adapter := &GormAdapter{db: tx, driverName: g.driverName} - return fn(adapter) - }) + run := func() error { + return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { + adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName} + return fn(adapter) + }) + } + err = run() + if isDBClosed(err) { + if reconnErr := g.reconnectDB(); reconnErr == nil { + err = run() + } + } + return err } func (g *GormAdapter) GetUnderlyingDB() interface{} { - return g.db + return g.getDB() } func (g *GormAdapter) DriverName() string { - if g.db.Dialector == nil { + return normalizeGormDriverName(g.getDB()) +} + +func normalizeGormDriverName(db *gorm.DB) string { + if db == nil || db.Dialector == nil { return "" } // Normalize GORM's dialector name to match the project's canonical vocabulary. // GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql". // GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite". - switch name := g.db.Name(); name { + switch name := db.Name(); name { case "sqlserver": return "mssql" case "sqlite3": @@ -130,6 +233,7 @@ func (g *GormAdapter) DriverName() string { // GormSelectQuery implements SelectQuery for GORM type GormSelectQuery struct { db *gorm.DB + reconnect func(...*gorm.DB) error schema string // Separated schema name tableName string // Just the table name, without schema tableAlias string @@ -347,6 +451,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common. wrapper := &GormSelectQuery{ db: db, + reconnect: g.reconnect, driverName: g.driverName, } @@ -385,6 +490,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel wrapper := &GormSelectQuery{ db: db, + reconnect: g.reconnect, driverName: g.driverName, inJoinContext: true, // Mark as JOIN context joinTableAlias: strings.ToLower(relation), // Use relation name as alias @@ -444,7 +550,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error err = logger.HandlePanic("GormSelectQuery.Scan", r) } }() - err = g.db.WithContext(ctx).Find(dest).Error + run := func() error { + return g.db.WithContext(ctx).Find(dest).Error + } + err = run() + if isDBClosed(err) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + err = run() + } + } if err != nil { // Log SQL string for debugging sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -464,7 +578,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { if g.db.Statement.Model == nil { return fmt.Errorf("ScanModel requires Model() to be set before scanning") } - err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error + run := func() error { + return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error + } + err = run() + if isDBClosed(err) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + err = run() + } + } if err != nil { // Log SQL string for debugging sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -483,7 +605,15 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { } }() var count64 int64 - err = g.db.WithContext(ctx).Count(&count64).Error + run := func() error { + return g.db.WithContext(ctx).Count(&count64).Error + } + err = run() + if isDBClosed(err) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + err = run() + } + } if err != nil { // Log SQL string for debugging sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -502,7 +632,15 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) { } }() var count int64 - err = g.db.WithContext(ctx).Limit(1).Count(&count).Error + run := func() error { + return g.db.WithContext(ctx).Limit(1).Count(&count).Error + } + err = run() + if isDBClosed(err) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + err = run() + } + } if err != nil { // Log SQL string for debugging sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -516,6 +654,7 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) { // GormInsertQuery implements InsertQuery for GORM type GormInsertQuery struct { db *gorm.DB + reconnect func(...*gorm.DB) error model interface{} values map[string]interface{} } @@ -555,14 +694,21 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro err = logger.HandlePanic("GormInsertQuery.Exec", r) } }() - var result *gorm.DB - switch { - case g.model != nil: - result = g.db.WithContext(ctx).Create(g.model) - case g.values != nil: - result = g.db.WithContext(ctx).Create(g.values) - default: - result = g.db.WithContext(ctx).Create(map[string]interface{}{}) + run := func() *gorm.DB { + switch { + case g.model != nil: + return g.db.WithContext(ctx).Create(g.model) + case g.values != nil: + return g.db.WithContext(ctx).Create(g.values) + default: + return g.db.WithContext(ctx).Create(map[string]interface{}{}) + } + } + result := run() + if isDBClosed(result.Error) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + result = run() + } } return &GormResult{result: result}, result.Error } @@ -570,6 +716,7 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro // GormUpdateQuery implements UpdateQuery for GORM type GormUpdateQuery struct { db *gorm.DB + reconnect func(...*gorm.DB) error model interface{} updates interface{} } @@ -647,7 +794,15 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro err = logger.HandlePanic("GormUpdateQuery.Exec", r) } }() - result := g.db.WithContext(ctx).Updates(g.updates) + run := func() *gorm.DB { + return g.db.WithContext(ctx).Updates(g.updates) + } + result := run() + if isDBClosed(result.Error) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + result = run() + } + } if result.Error != nil { // Log SQL string for debugging sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -661,6 +816,7 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro // GormDeleteQuery implements DeleteQuery for GORM type GormDeleteQuery struct { db *gorm.DB + reconnect func(...*gorm.DB) error model interface{} } @@ -686,7 +842,15 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro err = logger.HandlePanic("GormDeleteQuery.Exec", r) } }() - result := g.db.WithContext(ctx).Delete(g.model) + run := func() *gorm.DB { + return g.db.WithContext(ctx).Delete(g.model) + } + result := run() + if isDBClosed(result.Error) && g.reconnect != nil { + if reconnErr := g.reconnect(g.db); reconnErr == nil { + result = run() + } + } if result.Error != nil { // Log SQL string for debugging sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB { diff --git a/pkg/dbmanager/connection.go b/pkg/dbmanager/connection.go index 2f73ed5..1690066 100644 --- a/pkg/dbmanager/connection.go +++ b/pkg/dbmanager/connection.go @@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats { return stats } +func (c *sqlConnection) reconnectForAdapter() error { + timeout := c.config.ConnectTimeout + if timeout <= 0 { + timeout = 10 * time.Second + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return c.Reconnect(ctx) +} + +func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) { + if err := c.reconnectForAdapter(); err != nil { + return nil, err + } + + return c.Native() +} + +func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) { + if err := c.reconnectForAdapter(); err != nil { + return nil, err + } + + return c.Bun() +} + +func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) { + if err := c.reconnectForAdapter(); err != nil { + return nil, err + } + + return c.GORM() +} + // getBunAdapter returns or creates the Bun adapter func (c *sqlConnection) getBunAdapter() (common.Database, error) { if c == nil { @@ -391,7 +427,7 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) { c.bunDB = bun.NewDB(native, dialect) } - c.bunAdapter = database.NewBunAdapter(c.bunDB) + c.bunAdapter = database.NewBunAdapter(c.bunDB).WithDBFactory(c.reopenBunForAdapter) return c.bunAdapter, nil } @@ -432,7 +468,7 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) { c.gormDB = db } - c.gormAdapter = database.NewGormAdapter(c.gormDB) + c.gormAdapter = database.NewGormAdapter(c.gormDB).WithDBFactory(c.reopenGORMForAdapter) return c.gormAdapter, nil } @@ -473,11 +509,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) { // Create a native adapter based on database type switch c.dbType { case DatabaseTypePostgreSQL: - c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)) + c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter) case DatabaseTypeSQLite: - c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)) + c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter) case DatabaseTypeMSSQL: - c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)) + c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter) default: return nil, ErrUnsupportedDatabase } diff --git a/pkg/dbmanager/factory_test.go b/pkg/dbmanager/factory_test.go index 38c0312..1e71c75 100644 --- a/pkg/dbmanager/factory_test.go +++ b/pkg/dbmanager/factory_test.go @@ -4,8 +4,13 @@ import ( "context" "database/sql" "testing" + "time" _ "github.com/mattn/go-sqlite3" + "gorm.io/gorm" + + "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database" + "github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers" ) func TestNewConnectionFromDB(t *testing.T) { @@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) { t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type()) } } + +func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) { + conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{ + Name: "test-native", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + DefaultORM: string(ORMTypeNative), + ConnectTimeout: 2 * time.Second, + }, providers.NewSQLiteProvider()) + + ctx := context.Background() + if err := conn.Connect(ctx); err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + db, err := conn.Database() + if err != nil { + t.Fatalf("Failed to get database adapter: %v", err) + } + + adapter, ok := db.(*database.PgSQLAdapter) + if !ok { + t.Fatalf("Expected PgSQLAdapter, got %T", db) + } + + underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB) + if !ok { + t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB()) + } + + if err := underlyingBefore.Close(); err != nil { + t.Fatalf("Failed to close underlying database: %v", err) + } + + if _, err := db.Exec(ctx, "SELECT 1"); err != nil { + t.Fatalf("Expected native adapter to reconnect, got error: %v", err) + } + + underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB) + if !ok { + t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB()) + } + + if underlyingAfter == underlyingBefore { + t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect") + } +} + +func TestDatabaseBunAdapterReconnectFactory(t *testing.T) { + conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{ + Name: "test-bun", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + DefaultORM: string(ORMTypeBun), + ConnectTimeout: 2 * time.Second, + }, providers.NewSQLiteProvider()) + + ctx := context.Background() + if err := conn.Connect(ctx); err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + db, err := conn.Database() + if err != nil { + t.Fatalf("Failed to get database adapter: %v", err) + } + + adapter, ok := db.(*database.BunAdapter) + if !ok { + t.Fatalf("Expected BunAdapter, got %T", db) + } + + underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error }) + if !ok { + t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB()) + } + + if err := underlyingBefore.Close(); err != nil { + t.Fatalf("Failed to close underlying Bun database: %v", err) + } + + if _, err := db.Exec(ctx, "SELECT 1"); err != nil { + t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err) + } + + underlyingAfter := adapter.GetUnderlyingDB() + if underlyingAfter == underlyingBefore { + t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect") + } +} + +func TestDatabaseGormAdapterReconnectFactory(t *testing.T) { + conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{ + Name: "test-gorm", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + DefaultORM: string(ORMTypeGORM), + ConnectTimeout: 2 * time.Second, + }, providers.NewSQLiteProvider()) + + ctx := context.Background() + if err := conn.Connect(ctx); err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + db, err := conn.Database() + if err != nil { + t.Fatalf("Failed to get database adapter: %v", err) + } + + adapter, ok := db.(*database.GormAdapter) + if !ok { + t.Fatalf("Expected GormAdapter, got %T", db) + } + + gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB) + if !ok { + t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB()) + } + + sqlBefore, err := gormBefore.DB() + if err != nil { + t.Fatalf("Failed to get underlying *sql.DB: %v", err) + } + + if err := sqlBefore.Close(); err != nil { + t.Fatalf("Failed to close underlying database: %v", err) + } + + count, err := db.NewSelect().Table("sqlite_master").Count(ctx) + if err != nil { + t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err) + } + if count < 0 { + t.Fatalf("Expected non-negative count, got %d", count) + } + + gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB) + if !ok { + t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB()) + } + + sqlAfter, err := gormAfter.DB() + if err != nil { + t.Fatalf("Failed to get reconnected *sql.DB: %v", err) + } + + if sqlAfter == sqlBefore { + t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect") + } +} diff --git a/pkg/dbmanager/manager.go b/pkg/dbmanager/manager.go index e5dfb5b..1ae4dfe 100644 --- a/pkg/dbmanager/manager.go +++ b/pkg/dbmanager/manager.go @@ -2,7 +2,9 @@ package dbmanager import ( "context" + "errors" "fmt" + "strings" "sync" "time" @@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() { "connection", item.name, "error", err) - // Attempt reconnection if enabled - if m.config.EnableAutoReconnect { + // Only reconnect when the client handle itself is closed/disconnected. + // For transient database restarts or network blips, *sql.DB can recover + // on its own; forcing Close()+Connect() here invalidates any cached ORM + // wrappers and callers that still hold the old handle. + if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) { logger.Info("Attempting reconnection: connection=%s", item.name) if err := item.conn.Reconnect(ctx); err != nil { logger.Error("Reconnection failed", @@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() { } else { logger.Info("Reconnection successful: connection=%s", item.name) } + } else if m.config.EnableAutoReconnect { + logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name) } } } } + +func shouldReconnectAfterHealthCheck(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, ErrConnectionClosed) { + return true + } + + return strings.Contains(err.Error(), "sql: database is closed") +} diff --git a/pkg/dbmanager/manager_test.go b/pkg/dbmanager/manager_test.go index 3497690..2f3d27f 100644 --- a/pkg/dbmanager/manager_test.go +++ b/pkg/dbmanager/manager_test.go @@ -3,12 +3,38 @@ package dbmanager import ( "context" "database/sql" + "fmt" "testing" "time" + "github.com/uptrace/bun" + "go.mongodb.org/mongo-driver/mongo" + "gorm.io/gorm" + + "github.com/bitechdev/ResolveSpec/pkg/common" + _ "github.com/mattn/go-sqlite3" ) +type healthCheckStubConnection struct { + healthErr error + reconnectCalls int +} + +func (c *healthCheckStubConnection) Name() string { return "stub" } +func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL } +func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") } +func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") } +func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") } +func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") } +func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") } +func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") } +func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil } +func (c *healthCheckStubConnection) Close() error { return nil } +func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr } +func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil } +func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} } + func TestBackgroundHealthChecker(t *testing.T) { // Create a SQLite in-memory database db, err := sql.Open("sqlite3", ":memory:") @@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) { t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections) } } + +func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) { + conn := &healthCheckStubConnection{ + healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"), + } + + mgr := &connectionManager{ + connections: map[string]Connection{"primary": conn}, + config: ManagerConfig{ + EnableAutoReconnect: true, + }, + } + + mgr.performHealthCheck() + + if conn.reconnectCalls != 0 { + t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls) + } +} + +func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) { + conn := &healthCheckStubConnection{ + healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")), + } + + mgr := &connectionManager{ + connections: map[string]Connection{"primary": conn}, + config: ManagerConfig{ + EnableAutoReconnect: true, + }, + } + + mgr.performHealthCheck() + + if conn.reconnectCalls != 1 { + t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls) + } +} diff --git a/pkg/security/providers.go b/pkg/security/providers.go index af57172..a10f45e 100644 --- a/pkg/security/providers.go +++ b/pkg/security/providers.go @@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error { return nil } +func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error { + db := a.getDB() + if db == nil { + return fmt.Errorf("database connection is nil") + } + + err := run(db) + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = run(a.getDB()) + } + } + + return err +} + func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { // Convert LoginRequest to JSON reqJSON, err := json.Marshal(req) @@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L var errorMsg sql.NullString var dataJSON sql.NullString - runLoginQuery := func() error { + err = a.runDBOpWithReconnect(func(db *sql.DB) error { query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) - return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) - } - err = runLoginQuery() - if isDBClosed(err) { - if reconnErr := a.reconnectDB(); reconnErr == nil { - err = runLoginQuery() - } - } + return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + }) if err != nil { return nil, fmt.Errorf("login query failed: %w", err) } @@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques var errorMsg sql.NullString var dataJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register) - err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + err = a.runDBOpWithReconnect(func(db *sql.DB) error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register) + return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + }) if err != nil { return nil, fmt.Errorf("register query failed: %w", err) } @@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e var errorMsg sql.NullString var dataJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout) - err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + err = a.runDBOpWithReconnect(func(db *sql.DB) error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout) + return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + }) if err != nil { return fmt.Errorf("logout query failed: %w", err) } @@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err var errorMsg sql.NullString var userJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) - err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) + err := a.runDBOpWithReconnect(func(db *sql.DB) error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) + return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) + }) if err != nil { return nil, fmt.Errorf("session query failed: %w", err) } @@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi var errorMsg sql.NullString var updatedUserJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate) - _ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) + _ = a.runDBOpWithReconnect(func(db *sql.DB) error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate) + return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) + }) } // RefreshToken implements Refreshable interface @@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s var errorMsg sql.NullString var userJSON sql.NullString // Get current session to pass to refresh - query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) - err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) + err := a.runDBOpWithReconnect(func(db *sql.DB) error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) + return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) + }) if err != nil { return nil, fmt.Errorf("refresh token query failed: %w", err) } @@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s var newErrorMsg sql.NullString var newUserJSON sql.NullString - refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken) - err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) + err = a.runDBOpWithReconnect(func(db *sql.DB) error { + refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken) + return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) + }) if err != nil { return nil, fmt.Errorf("refresh token generation failed: %w", err) } diff --git a/pkg/security/providers_test.go b/pkg/security/providers_test.go index b20f193..9b50dfa 100644 --- a/pkg/security/providers_test.go +++ b/pkg/security/providers_test.go @@ -3,6 +3,7 @@ package security import ( "context" "database/sql" + "fmt" "net/http" "net/http/httptest" "testing" @@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) { }) } +func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) { + newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) { + t.Helper() + + primaryDB, primaryMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create primary mock db: %v", err) + } + + reconnectDB, reconnectMock, err := sqlmock.New() + if err != nil { + primaryDB.Close() + t.Fatalf("failed to create reconnect mock db: %v", err) + } + + cacheProvider := cache.NewMemoryProvider(&cache.Options{ + DefaultTTL: 1 * time.Minute, + MaxSize: 1000, + }) + + auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{ + Cache: cache.NewCache(cacheProvider), + DBFactory: func() (*sql.DB, error) { + return reconnectDB, nil + }, + }) + + cleanup := func() { + _ = primaryDB.Close() + _ = reconnectDB.Close() + } + + return auth, primaryMock, reconnectMock, cleanup + } + + t.Run("Authenticate reconnects after closed database", func(t *testing.T) { + auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t) + defer cleanup() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer reconnect-auth-token") + + primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`). + WithArgs("reconnect-auth-token", "authenticate"). + WillReturnError(fmt.Errorf("sql: database is closed")) + + reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}). + AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`) + + reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`). + WithArgs("reconnect-auth-token", "authenticate"). + WillReturnRows(reconnectRows) + + userCtx, err := auth.Authenticate(req) + if err != nil { + t.Fatalf("expected authenticate to reconnect, got %v", err) + } + if userCtx.UserID != 7 { + t.Fatalf("expected user ID 7, got %d", userCtx.UserID) + } + + if err := primaryMock.ExpectationsWereMet(); err != nil { + t.Fatalf("primary db expectations not met: %v", err) + } + if err := reconnectMock.ExpectationsWereMet(); err != nil { + t.Fatalf("reconnect db expectations not met: %v", err) + } + }) + + t.Run("Register reconnects after closed database", func(t *testing.T) { + auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t) + defer cleanup() + + req := RegisterRequest{ + Username: "reconnect-register", + Password: "password123", + Email: "reconnect@example.com", + UserLevel: 1, + Roles: []string{"user"}, + } + + primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`). + WithArgs(sqlmock.AnyArg()). + WillReturnError(fmt.Errorf("sql: database is closed")) + + reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}). + AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`) + + reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(reconnectRows) + + resp, err := auth.Register(context.Background(), req) + if err != nil { + t.Fatalf("expected register to reconnect, got %v", err) + } + if resp.Token != "reconnected-register-token" { + t.Fatalf("expected refreshed token, got %s", resp.Token) + } + + if err := primaryMock.ExpectationsWereMet(); err != nil { + t.Fatalf("primary db expectations not met: %v", err) + } + if err := reconnectMock.ExpectationsWereMet(); err != nil { + t.Fatalf("reconnect db expectations not met: %v", err) + } + }) + + t.Run("Logout reconnects after closed database", func(t *testing.T) { + auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t) + defer cleanup() + + req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9} + + primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`). + WithArgs(sqlmock.AnyArg()). + WillReturnError(fmt.Errorf("sql: database is closed")) + + reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}). + AddRow(true, nil, nil) + + reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(reconnectRows) + + if err := auth.Logout(context.Background(), req); err != nil { + t.Fatalf("expected logout to reconnect, got %v", err) + } + + if err := primaryMock.ExpectationsWereMet(); err != nil { + t.Fatalf("primary db expectations not met: %v", err) + } + if err := reconnectMock.ExpectationsWereMet(); err != nil { + t.Fatalf("reconnect db expectations not met: %v", err) + } + }) + + t.Run("RefreshToken reconnects after closed database", func(t *testing.T) { + auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t) + defer cleanup() + + refreshToken := "refresh-reconnect-token" + + primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`). + WithArgs(refreshToken, "refresh"). + WillReturnError(fmt.Errorf("sql: database is closed")) + + sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}). + AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`) + + reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`). + WithArgs(refreshToken, "refresh"). + WillReturnRows(sessionRows) + + refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}). + AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`) + + reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`). + WithArgs(refreshToken, sqlmock.AnyArg()). + WillReturnRows(refreshRows) + + resp, err := auth.RefreshToken(context.Background(), refreshToken) + if err != nil { + t.Fatalf("expected refresh token to reconnect, got %v", err) + } + if resp.Token != "refreshed-token" { + t.Fatalf("expected refreshed-token, got %s", resp.Token) + } + + if err := primaryMock.ExpectationsWereMet(); err != nil { + t.Fatalf("primary db expectations not met: %v", err) + } + if err := reconnectMock.ExpectationsWereMet(); err != nil { + t.Fatalf("reconnect db expectations not met: %v", err) + } + }) + + t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) { + auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t) + defer cleanup() + + userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"} + + primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`). + WithArgs("activity-token", sqlmock.AnyArg()). + WillReturnError(fmt.Errorf("sql: database is closed")) + + reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}). + AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`) + + reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`). + WithArgs("activity-token", sqlmock.AnyArg()). + WillReturnRows(reconnectRows) + + auth.updateSessionActivity(context.Background(), "activity-token", userCtx) + + if err := primaryMock.ExpectationsWereMet(); err != nil { + t.Fatalf("primary db expectations not met: %v", err) + } + if err := reconnectMock.ExpectationsWereMet(); err != nil { + t.Fatalf("reconnect db expectations not met: %v", err) + } + }) +} + // Test JWTAuthenticator func TestJWTAuthenticator(t *testing.T) { db, mock, err := sqlmock.New()