Compare commits

..

2 Commits

Author SHA1 Message Date
Hein
4fc25c60ae fix(db): correct connection pool assignment in GORM adapter
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -29m43s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m6s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m11s
Build , Vet Test, and Lint / Build (push) Successful in -29m33s
Tests / Unit Tests (push) Successful in -30m4s
Tests / Integration Tests (push) Failing after -30m13s
2026-04-10 11:20:44 +02:00
Hein
16a960d973 feat(db): add reconnect logic for database adapters
* Implement reconnect functionality in GormAdapter and other database adapters.
* Introduce a DBFactory to handle reconnections.
* Update health check logic to skip reconnects for transient failures.
* Add tests for reconnect behavior in DatabaseAuthenticator.
2026-04-10 11:18:39 +02:00
8 changed files with 736 additions and 65 deletions

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ test.db
tests/data/
node_modules/
resolvespec-js/dist/
.codex

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"gorm.io/gorm"
@@ -15,7 +16,9 @@ import (
// GormAdapter adapts GORM to work with our Database interface
type GormAdapter struct {
dbMu sync.RWMutex
db *gorm.DB
dbFactory func() (*gorm.DB, error)
driverName string
}
@@ -27,10 +30,72 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
return adapter
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
g.dbFactory = factory
return g
}
func (g *GormAdapter) getDB() *gorm.DB {
g.dbMu.RLock()
defer g.dbMu.RUnlock()
return g.db
}
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
if g.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
freshDB, err := g.dbFactory()
if err != nil {
return err
}
g.dbMu.Lock()
previous := g.db
g.db = freshDB
g.driverName = normalizeGormDriverName(freshDB)
g.dbMu.Unlock()
if previous != nil {
syncGormConnPool(previous, freshDB)
}
for _, target := range targets {
if target != nil && target != previous {
syncGormConnPool(target, freshDB)
}
}
return nil
}
func syncGormConnPool(target, fresh *gorm.DB) {
if target == nil || fresh == nil {
return
}
if target.Config != nil && fresh.Config != nil {
target.ConnPool = fresh.ConnPool
}
if target.Statement != nil {
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
target.Statement.ConnPool = fresh.Statement.ConnPool
} else if fresh.Config != nil {
target.Statement.ConnPool = fresh.ConnPool
}
target.Statement.DB = target
}
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.dbMu.Lock()
g.db = g.db.Debug()
g.dbMu.Unlock()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
@@ -44,19 +109,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
}
func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db, driverName: g.driverName}
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB}
}
func (g *GormAdapter) NewInsert() common.InsertQuery {
return &GormInsertQuery{db: g.db}
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB}
}
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
return &GormUpdateQuery{db: g.db}
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB}
}
func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db}
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -65,7 +130,15 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Exec(query, args...)
}
result := run()
if isDBClosed(result.Error) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
result = run()
}
}
return &GormResult{result: result}, result.Error
}
@@ -75,15 +148,32 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
run := func() error {
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
}
err = run()
if isDBClosed(err) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
err = run()
}
}
return err
}
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx := g.db.WithContext(ctx).Begin()
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Begin()
}
tx := run()
if isDBClosed(tx.Error) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
tx = run()
}
}
if tx.Error != nil {
return nil, tx.Error
}
return &GormAdapter{db: tx, driverName: g.driverName}, nil
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil
}
func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -100,24 +190,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, driverName: g.driverName}
return fn(adapter)
})
run := func() error {
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}
return fn(adapter)
})
}
err = run()
if isDBClosed(err) {
if reconnErr := g.reconnectDB(); reconnErr == nil {
err = run()
}
}
return err
}
func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db
return g.getDB()
}
func (g *GormAdapter) DriverName() string {
if g.db.Dialector == nil {
return normalizeGormDriverName(g.getDB())
}
func normalizeGormDriverName(db *gorm.DB) string {
if db == nil || db.Dialector == nil {
return ""
}
// Normalize GORM's dialector name to match the project's canonical vocabulary.
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
switch name := g.db.Name(); name {
switch name := db.Name(); name {
case "sqlserver":
return "mssql"
case "sqlite3":
@@ -130,6 +233,7 @@ func (g *GormAdapter) DriverName() string {
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
@@ -347,6 +451,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
wrapper := &GormSelectQuery{
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
}
@@ -385,6 +490,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
wrapper := &GormSelectQuery{
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
@@ -444,7 +550,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
err = g.db.WithContext(ctx).Find(dest).Error
run := func() error {
return g.db.WithContext(ctx).Find(dest).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -464,7 +578,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
run := func() error {
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -483,7 +605,15 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
}
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
run := func() error {
return g.db.WithContext(ctx).Count(&count64).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -502,7 +632,15 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
}
}()
var count int64
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
run := func() error {
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
}
err = run()
if isDBClosed(err) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
err = run()
}
}
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -515,9 +653,10 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
// GormInsertQuery implements InsertQuery for GORM
type GormInsertQuery struct {
db *gorm.DB
model interface{}
values map[string]interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
values map[string]interface{}
}
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
@@ -555,23 +694,31 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
switch {
case g.model != nil:
result = g.db.WithContext(ctx).Create(g.model)
case g.values != nil:
result = g.db.WithContext(ctx).Create(g.values)
default:
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
run := func() *gorm.DB {
switch {
case g.model != nil:
return g.db.WithContext(ctx).Create(g.model)
case g.values != nil:
return g.db.WithContext(ctx).Create(g.values)
default:
return g.db.WithContext(ctx).Create(map[string]interface{}{})
}
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
return &GormResult{result: result}, result.Error
}
// GormUpdateQuery implements UpdateQuery for GORM
type GormUpdateQuery struct {
db *gorm.DB
model interface{}
updates interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
updates interface{}
}
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
@@ -647,7 +794,15 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
run := func() *gorm.DB {
return g.db.WithContext(ctx).Updates(g.updates)
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
@@ -660,8 +815,9 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
// GormDeleteQuery implements DeleteQuery for GORM
type GormDeleteQuery struct {
db *gorm.DB
model interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
}
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
@@ -686,7 +842,15 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
run := func() *gorm.DB {
return g.db.WithContext(ctx).Delete(g.model)
}
result := run()
if isDBClosed(result.Error) && g.reconnect != nil {
if reconnErr := g.reconnect(g.db); reconnErr == nil {
result = run()
}
}
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {

View File

@@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
return stats
}
func (c *sqlConnection) reconnectForAdapter() error {
timeout := c.config.ConnectTimeout
if timeout <= 0 {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.Reconnect(ctx)
}
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
if err := c.reconnectForAdapter(); err != nil {
return nil, err
}
return c.Native()
}
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
if err := c.reconnectForAdapter(); err != nil {
return nil, err
}
return c.Bun()
}
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
if err := c.reconnectForAdapter(); err != nil {
return nil, err
}
return c.GORM()
}
// getBunAdapter returns or creates the Bun adapter
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
if c == nil {
@@ -391,7 +427,7 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
c.bunDB = bun.NewDB(native, dialect)
}
c.bunAdapter = database.NewBunAdapter(c.bunDB)
c.bunAdapter = database.NewBunAdapter(c.bunDB).WithDBFactory(c.reopenBunForAdapter)
return c.bunAdapter, nil
}
@@ -432,7 +468,7 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
c.gormDB = db
}
c.gormAdapter = database.NewGormAdapter(c.gormDB)
c.gormAdapter = database.NewGormAdapter(c.gormDB).WithDBFactory(c.reopenGORMForAdapter)
return c.gormAdapter, nil
}
@@ -473,11 +509,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
// Create a native adapter based on database type
switch c.dbType {
case DatabaseTypePostgreSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
case DatabaseTypeSQLite:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
case DatabaseTypeMSSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
default:
return nil, ErrUnsupportedDatabase
}

View File

@@ -4,8 +4,13 @@ import (
"context"
"database/sql"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
)
func TestNewConnectionFromDB(t *testing.T) {
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
}
}
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
Name: "test-native",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
DefaultORM: string(ORMTypeNative),
ConnectTimeout: 2 * time.Second,
}, providers.NewSQLiteProvider())
ctx := context.Background()
if err := conn.Connect(ctx); err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
db, err := conn.Database()
if err != nil {
t.Fatalf("Failed to get database adapter: %v", err)
}
adapter, ok := db.(*database.PgSQLAdapter)
if !ok {
t.Fatalf("Expected PgSQLAdapter, got %T", db)
}
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
if !ok {
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
}
if err := underlyingBefore.Close(); err != nil {
t.Fatalf("Failed to close underlying database: %v", err)
}
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
}
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
if !ok {
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
}
if underlyingAfter == underlyingBefore {
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
}
}
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
Name: "test-bun",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
DefaultORM: string(ORMTypeBun),
ConnectTimeout: 2 * time.Second,
}, providers.NewSQLiteProvider())
ctx := context.Background()
if err := conn.Connect(ctx); err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
db, err := conn.Database()
if err != nil {
t.Fatalf("Failed to get database adapter: %v", err)
}
adapter, ok := db.(*database.BunAdapter)
if !ok {
t.Fatalf("Expected BunAdapter, got %T", db)
}
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
if !ok {
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
}
if err := underlyingBefore.Close(); err != nil {
t.Fatalf("Failed to close underlying Bun database: %v", err)
}
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
}
underlyingAfter := adapter.GetUnderlyingDB()
if underlyingAfter == underlyingBefore {
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
}
}
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
Name: "test-gorm",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
DefaultORM: string(ORMTypeGORM),
ConnectTimeout: 2 * time.Second,
}, providers.NewSQLiteProvider())
ctx := context.Background()
if err := conn.Connect(ctx); err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
db, err := conn.Database()
if err != nil {
t.Fatalf("Failed to get database adapter: %v", err)
}
adapter, ok := db.(*database.GormAdapter)
if !ok {
t.Fatalf("Expected GormAdapter, got %T", db)
}
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
if !ok {
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
}
sqlBefore, err := gormBefore.DB()
if err != nil {
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
}
if err := sqlBefore.Close(); err != nil {
t.Fatalf("Failed to close underlying database: %v", err)
}
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
if err != nil {
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
}
if count < 0 {
t.Fatalf("Expected non-negative count, got %d", count)
}
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
if !ok {
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
}
sqlAfter, err := gormAfter.DB()
if err != nil {
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
}
if sqlAfter == sqlBefore {
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
}
}

View File

@@ -2,7 +2,9 @@ package dbmanager
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
"connection", item.name,
"error", err)
// Attempt reconnection if enabled
if m.config.EnableAutoReconnect {
// Only reconnect when the client handle itself is closed/disconnected.
// For transient database restarts or network blips, *sql.DB can recover
// on its own; forcing Close()+Connect() here invalidates any cached ORM
// wrappers and callers that still hold the old handle.
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
logger.Info("Attempting reconnection: connection=%s", item.name)
if err := item.conn.Reconnect(ctx); err != nil {
logger.Error("Reconnection failed",
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
} else {
logger.Info("Reconnection successful: connection=%s", item.name)
}
} else if m.config.EnableAutoReconnect {
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
}
}
}
}
func shouldReconnectAfterHealthCheck(err error) bool {
if err == nil {
return false
}
if errors.Is(err, ErrConnectionClosed) {
return true
}
return strings.Contains(err.Error(), "sql: database is closed")
}

View File

@@ -3,12 +3,38 @@ package dbmanager
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/uptrace/bun"
"go.mongodb.org/mongo-driver/mongo"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
_ "github.com/mattn/go-sqlite3"
)
type healthCheckStubConnection struct {
healthErr error
reconnectCalls int
}
func (c *healthCheckStubConnection) Name() string { return "stub" }
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
func (c *healthCheckStubConnection) Close() error { return nil }
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
func TestBackgroundHealthChecker(t *testing.T) {
// Create a SQLite in-memory database
db, err := sql.Open("sqlite3", ":memory:")
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
}
}
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
conn := &healthCheckStubConnection{
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
}
mgr := &connectionManager{
connections: map[string]Connection{"primary": conn},
config: ManagerConfig{
EnableAutoReconnect: true,
},
}
mgr.performHealthCheck()
if conn.reconnectCalls != 0 {
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
}
}
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
conn := &healthCheckStubConnection{
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
}
mgr := &connectionManager{
connections: map[string]Connection{"primary": conn},
config: ManagerConfig{
EnableAutoReconnect: true,
},
}
mgr.performHealthCheck()
if conn.reconnectCalls != 1 {
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
}
}

View File

@@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error {
return nil
}
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
db := a.getDB()
if db == nil {
return fmt.Errorf("database connection is nil")
}
err := run(db)
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = run(a.getDB())
}
}
return err
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req)
@@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
var errorMsg sql.NullString
var dataJSON sql.NullString
runLoginQuery := func() error {
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
var errorMsg sql.NullString
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("register query failed: %w", err)
}
@@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
var errorMsg sql.NullString
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString
var userJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
})
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
@@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
var errorMsg sql.NullString
var updatedUserJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
})
}
// RefreshToken implements Refreshable interface
@@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var errorMsg sql.NullString
var userJSON sql.NullString
// Get current session to pass to refresh
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
})
if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err)
}
@@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var newErrorMsg sql.NullString
var newUserJSON sql.NullString
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
})
if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err)
}

View File

@@ -3,6 +3,7 @@ package security
import (
"context"
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
})
}
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
t.Helper()
primaryDB, primaryMock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create primary mock db: %v", err)
}
reconnectDB, reconnectMock, err := sqlmock.New()
if err != nil {
primaryDB.Close()
t.Fatalf("failed to create reconnect mock db: %v", err)
}
cacheProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 1000,
})
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
Cache: cache.NewCache(cacheProvider),
DBFactory: func() (*sql.DB, error) {
return reconnectDB, nil
},
})
cleanup := func() {
_ = primaryDB.Close()
_ = reconnectDB.Close()
}
return auth, primaryMock, reconnectMock, cleanup
}
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("reconnect-auth-token", "authenticate").
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("reconnect-auth-token", "authenticate").
WillReturnRows(reconnectRows)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected authenticate to reconnect, got %v", err)
}
if userCtx.UserID != 7 {
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("Register reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
req := RegisterRequest{
Username: "reconnect-register",
Password: "password123",
Email: "reconnect@example.com",
UserLevel: 1,
Roles: []string{"user"},
}
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(reconnectRows)
resp, err := auth.Register(context.Background(), req)
if err != nil {
t.Fatalf("expected register to reconnect, got %v", err)
}
if resp.Token != "reconnected-register-token" {
t.Fatalf("expected refreshed token, got %s", resp.Token)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("Logout reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
WithArgs(sqlmock.AnyArg()).
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(true, nil, nil)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(reconnectRows)
if err := auth.Logout(context.Background(), req); err != nil {
t.Fatalf("expected logout to reconnect, got %v", err)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
refreshToken := "refresh-reconnect-token"
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs(refreshToken, "refresh").
WillReturnError(fmt.Errorf("sql: database is closed"))
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs(refreshToken, "refresh").
WillReturnRows(sessionRows)
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
WithArgs(refreshToken, sqlmock.AnyArg()).
WillReturnRows(refreshRows)
resp, err := auth.RefreshToken(context.Background(), refreshToken)
if err != nil {
t.Fatalf("expected refresh token to reconnect, got %v", err)
}
if resp.Token != "refreshed-token" {
t.Fatalf("expected refreshed-token, got %s", resp.Token)
}
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
defer cleanup()
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
WithArgs("activity-token", sqlmock.AnyArg()).
WillReturnError(fmt.Errorf("sql: database is closed"))
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
WithArgs("activity-token", sqlmock.AnyArg()).
WillReturnRows(reconnectRows)
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
if err := primaryMock.ExpectationsWereMet(); err != nil {
t.Fatalf("primary db expectations not met: %v", err)
}
if err := reconnectMock.ExpectationsWereMet(); err != nil {
t.Fatalf("reconnect db expectations not met: %v", err)
}
})
}
// Test JWTAuthenticator
func TestJWTAuthenticator(t *testing.T) {
db, mock, err := sqlmock.New()