mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-10 09:56:24 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fc25c60ae | ||
|
|
16a960d973 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ test.db
|
||||
tests/data/
|
||||
node_modules/
|
||||
resolvespec-js/dist/
|
||||
.codex
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -15,7 +16,9 @@ import (
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
type GormAdapter struct {
|
||||
dbMu sync.RWMutex
|
||||
db *gorm.DB
|
||||
dbFactory func() (*gorm.DB, error)
|
||||
driverName string
|
||||
}
|
||||
|
||||
@@ -27,10 +30,72 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
|
||||
g.dbFactory = factory
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormAdapter) getDB() *gorm.DB {
|
||||
g.dbMu.RLock()
|
||||
defer g.dbMu.RUnlock()
|
||||
return g.db
|
||||
}
|
||||
|
||||
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
|
||||
if g.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
|
||||
freshDB, err := g.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.dbMu.Lock()
|
||||
previous := g.db
|
||||
g.db = freshDB
|
||||
g.driverName = normalizeGormDriverName(freshDB)
|
||||
g.dbMu.Unlock()
|
||||
|
||||
if previous != nil {
|
||||
syncGormConnPool(previous, freshDB)
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if target != nil && target != previous {
|
||||
syncGormConnPool(target, freshDB)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncGormConnPool(target, fresh *gorm.DB) {
|
||||
if target == nil || fresh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if target.Config != nil && fresh.Config != nil {
|
||||
target.ConnPool = fresh.ConnPool
|
||||
}
|
||||
|
||||
if target.Statement != nil {
|
||||
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
|
||||
target.Statement.ConnPool = fresh.Statement.ConnPool
|
||||
} else if fresh.Config != nil {
|
||||
target.Statement.ConnPool = fresh.ConnPool
|
||||
}
|
||||
target.Statement.DB = target
|
||||
}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||
g.dbMu.Lock()
|
||||
g.db = g.db.Debug()
|
||||
g.dbMu.Unlock()
|
||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||
return g
|
||||
}
|
||||
@@ -44,19 +109,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
||||
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||
return &GormInsertQuery{db: g.db}
|
||||
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &GormUpdateQuery{db: g.db}
|
||||
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
return &GormDeleteQuery{db: g.db}
|
||||
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -65,7 +130,15 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Exec(query, args...)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -75,15 +148,32 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx := g.db.WithContext(ctx).Begin()
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Begin()
|
||||
}
|
||||
tx := run()
|
||||
if isDBClosed(tx.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
tx = run()
|
||||
}
|
||||
}
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
||||
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil
|
||||
}
|
||||
|
||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -100,24 +190,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
return g.getDB()
|
||||
}
|
||||
|
||||
func (g *GormAdapter) DriverName() string {
|
||||
if g.db.Dialector == nil {
|
||||
return normalizeGormDriverName(g.getDB())
|
||||
}
|
||||
|
||||
func normalizeGormDriverName(db *gorm.DB) string {
|
||||
if db == nil || db.Dialector == nil {
|
||||
return ""
|
||||
}
|
||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||
switch name := g.db.Name(); name {
|
||||
switch name := db.Name(); name {
|
||||
case "sqlserver":
|
||||
return "mssql"
|
||||
case "sqlite3":
|
||||
@@ -130,6 +233,7 @@ func (g *GormAdapter) DriverName() string {
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
@@ -347,6 +451,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
}
|
||||
|
||||
@@ -385,6 +490,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
@@ -444,7 +550,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
err = g.db.WithContext(ctx).Find(dest).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -464,7 +578,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -483,7 +605,15 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
}
|
||||
}()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Count(&count64).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -502,7 +632,15 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
}
|
||||
}()
|
||||
var count int64
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -516,6 +654,7 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
// GormInsertQuery implements InsertQuery for GORM
|
||||
type GormInsertQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
}
|
||||
@@ -555,14 +694,21 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
run := func() *gorm.DB {
|
||||
switch {
|
||||
case g.model != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.model)
|
||||
return g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.values)
|
||||
return g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
return g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
}
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
@@ -570,6 +716,7 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
// GormUpdateQuery implements UpdateQuery for GORM
|
||||
type GormUpdateQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
updates interface{}
|
||||
}
|
||||
@@ -647,7 +794,15 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Updates(g.updates)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -661,6 +816,7 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
// GormDeleteQuery implements DeleteQuery for GORM
|
||||
type GormDeleteQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
}
|
||||
|
||||
@@ -686,7 +842,15 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Delete(g.model)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
|
||||
@@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reconnectForAdapter() error {
|
||||
timeout := c.config.ConnectTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.Reconnect(ctx)
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Bun()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.GORM()
|
||||
}
|
||||
|
||||
// getBunAdapter returns or creates the Bun adapter
|
||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
if c == nil {
|
||||
@@ -391,7 +427,7 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
c.bunDB = bun.NewDB(native, dialect)
|
||||
}
|
||||
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB).WithDBFactory(c.reopenBunForAdapter)
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
@@ -432,7 +468,7 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
||||
c.gormDB = db
|
||||
}
|
||||
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB).WithDBFactory(c.reopenGORMForAdapter)
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
@@ -473,11 +509,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||
case DatabaseTypeSQLite:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||
case DatabaseTypeMSSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
|
||||
@@ -4,8 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
func TestNewConnectionFromDB(t *testing.T) {
|
||||
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-native",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeNative),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.PgSQLAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected PgSQLAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-bun",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeBun),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.BunAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BunAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying Bun database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter := adapter.GetUnderlyingDB()
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-gorm",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeGORM),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.GormAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected GormAdapter, got %T", db)
|
||||
}
|
||||
|
||||
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlBefore, err := gormBefore.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if err := sqlBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
|
||||
}
|
||||
if count < 0 {
|
||||
t.Fatalf("Expected non-negative count, got %d", count)
|
||||
}
|
||||
|
||||
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlAfter, err := gormAfter.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if sqlAfter == sqlBefore {
|
||||
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
|
||||
"connection", item.name,
|
||||
"error", err)
|
||||
|
||||
// Attempt reconnection if enabled
|
||||
if m.config.EnableAutoReconnect {
|
||||
// Only reconnect when the client handle itself is closed/disconnected.
|
||||
// For transient database restarts or network blips, *sql.DB can recover
|
||||
// on its own; forcing Close()+Connect() here invalidates any cached ORM
|
||||
// wrappers and callers that still hold the old handle.
|
||||
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
|
||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||
if err := item.conn.Reconnect(ctx); err != nil {
|
||||
logger.Error("Reconnection failed",
|
||||
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
|
||||
} else {
|
||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||
}
|
||||
} else if m.config.EnableAutoReconnect {
|
||||
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldReconnectAfterHealthCheck(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrConnectionClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
@@ -3,12 +3,38 @@ package dbmanager
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type healthCheckStubConnection struct {
|
||||
healthErr error
|
||||
reconnectCalls int
|
||||
}
|
||||
|
||||
func (c *healthCheckStubConnection) Name() string { return "stub" }
|
||||
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
|
||||
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
|
||||
func (c *healthCheckStubConnection) Close() error { return nil }
|
||||
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
|
||||
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
|
||||
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
|
||||
|
||||
func TestBackgroundHealthChecker(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
|
||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 0 {
|
||||
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 1 {
|
||||
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
|
||||
db := a.getDB()
|
||||
if db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
err := run(db)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = run(a.getDB())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Convert LoginRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
runLoginQuery := func() error {
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
@@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
@@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
@@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
}
|
||||
@@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package security
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
|
||||
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
|
||||
t.Helper()
|
||||
|
||||
primaryDB, primaryMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create primary mock db: %v", err)
|
||||
}
|
||||
|
||||
reconnectDB, reconnectMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
primaryDB.Close()
|
||||
t.Fatalf("failed to create reconnect mock db: %v", err)
|
||||
}
|
||||
|
||||
cacheProvider := cache.NewMemoryProvider(&cache.Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
})
|
||||
|
||||
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
|
||||
Cache: cache.NewCache(cacheProvider),
|
||||
DBFactory: func() (*sql.DB, error) {
|
||||
return reconnectDB, nil
|
||||
},
|
||||
})
|
||||
|
||||
cleanup := func() {
|
||||
_ = primaryDB.Close()
|
||||
_ = reconnectDB.Close()
|
||||
}
|
||||
|
||||
return auth, primaryMock, reconnectMock, cleanup
|
||||
}
|
||||
|
||||
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
userCtx, err := auth.Authenticate(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected authenticate to reconnect, got %v", err)
|
||||
}
|
||||
if userCtx.UserID != 7 {
|
||||
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Register reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := RegisterRequest{
|
||||
Username: "reconnect-register",
|
||||
Password: "password123",
|
||||
Email: "reconnect@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
resp, err := auth.Register(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected register to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "reconnected-register-token" {
|
||||
t.Fatalf("expected refreshed token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, nil)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
if err := auth.Logout(context.Background(), req); err != nil {
|
||||
t.Fatalf("expected logout to reconnect, got %v", err)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
refreshToken := "refresh-reconnect-token"
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnRows(sessionRows)
|
||||
|
||||
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
|
||||
WithArgs(refreshToken, sqlmock.AnyArg()).
|
||||
WillReturnRows(refreshRows)
|
||||
|
||||
resp, err := auth.RefreshToken(context.Background(), refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("expected refresh token to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "refreshed-token" {
|
||||
t.Fatalf("expected refreshed-token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test JWTAuthenticator
|
||||
func TestJWTAuthenticator(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
|
||||
Reference in New Issue
Block a user