mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-13 11:23:52 +00:00
feat(db): add reconnect logic for database adapters
* Implement reconnect functionality in GormAdapter and other database adapters. * Introduce a DBFactory to handle reconnections. * Update health check logic to skip reconnects for transient failures. * Add tests for reconnect behavior in DatabaseAuthenticator.
This commit is contained in:
@@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reconnectForAdapter() error {
|
||||
timeout := c.config.ConnectTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.Reconnect(ctx)
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Bun()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.GORM()
|
||||
}
|
||||
|
||||
// getBunAdapter returns or creates the Bun adapter
|
||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
if c == nil {
|
||||
@@ -391,7 +427,7 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
c.bunDB = bun.NewDB(native, dialect)
|
||||
}
|
||||
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB).WithDBFactory(c.reopenBunForAdapter)
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
@@ -432,7 +468,7 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
||||
c.gormDB = db
|
||||
}
|
||||
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB).WithDBFactory(c.reopenGORMForAdapter)
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
@@ -473,11 +509,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||
case DatabaseTypeSQLite:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||
case DatabaseTypeMSSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
|
||||
@@ -4,8 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
func TestNewConnectionFromDB(t *testing.T) {
|
||||
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-native",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeNative),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.PgSQLAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected PgSQLAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-bun",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeBun),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.BunAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BunAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying Bun database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter := adapter.GetUnderlyingDB()
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-gorm",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeGORM),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.GormAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected GormAdapter, got %T", db)
|
||||
}
|
||||
|
||||
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlBefore, err := gormBefore.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if err := sqlBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
|
||||
}
|
||||
if count < 0 {
|
||||
t.Fatalf("Expected non-negative count, got %d", count)
|
||||
}
|
||||
|
||||
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlAfter, err := gormAfter.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if sqlAfter == sqlBefore {
|
||||
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
|
||||
"connection", item.name,
|
||||
"error", err)
|
||||
|
||||
// Attempt reconnection if enabled
|
||||
if m.config.EnableAutoReconnect {
|
||||
// Only reconnect when the client handle itself is closed/disconnected.
|
||||
// For transient database restarts or network blips, *sql.DB can recover
|
||||
// on its own; forcing Close()+Connect() here invalidates any cached ORM
|
||||
// wrappers and callers that still hold the old handle.
|
||||
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
|
||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||
if err := item.conn.Reconnect(ctx); err != nil {
|
||||
logger.Error("Reconnection failed",
|
||||
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
|
||||
} else {
|
||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||
}
|
||||
} else if m.config.EnableAutoReconnect {
|
||||
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldReconnectAfterHealthCheck(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrConnectionClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
@@ -3,12 +3,38 @@ package dbmanager
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type healthCheckStubConnection struct {
|
||||
healthErr error
|
||||
reconnectCalls int
|
||||
}
|
||||
|
||||
func (c *healthCheckStubConnection) Name() string { return "stub" }
|
||||
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
|
||||
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
|
||||
func (c *healthCheckStubConnection) Close() error { return nil }
|
||||
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
|
||||
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
|
||||
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
|
||||
|
||||
func TestBackgroundHealthChecker(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
|
||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 0 {
|
||||
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 1 {
|
||||
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user