mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-10 18:03:57 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fc25c60ae | ||
|
|
16a960d973 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ test.db
|
|||||||
tests/data/
|
tests/data/
|
||||||
node_modules/
|
node_modules/
|
||||||
resolvespec-js/dist/
|
resolvespec-js/dist/
|
||||||
|
.codex
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
@@ -15,7 +16,9 @@ import (
|
|||||||
|
|
||||||
// GormAdapter adapts GORM to work with our Database interface
|
// GormAdapter adapts GORM to work with our Database interface
|
||||||
type GormAdapter struct {
|
type GormAdapter struct {
|
||||||
|
dbMu sync.RWMutex
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
dbFactory func() (*gorm.DB, error)
|
||||||
driverName string
|
driverName string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,10 +30,72 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
|||||||
return adapter
|
return adapter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
|
||||||
|
g.dbFactory = factory
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GormAdapter) getDB() *gorm.DB {
|
||||||
|
g.dbMu.RLock()
|
||||||
|
defer g.dbMu.RUnlock()
|
||||||
|
return g.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
|
||||||
|
if g.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
freshDB, err := g.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
g.dbMu.Lock()
|
||||||
|
previous := g.db
|
||||||
|
g.db = freshDB
|
||||||
|
g.driverName = normalizeGormDriverName(freshDB)
|
||||||
|
g.dbMu.Unlock()
|
||||||
|
|
||||||
|
if previous != nil {
|
||||||
|
syncGormConnPool(previous, freshDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target != previous {
|
||||||
|
syncGormConnPool(target, freshDB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncGormConnPool(target, fresh *gorm.DB) {
|
||||||
|
if target == nil || fresh == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Config != nil && fresh.Config != nil {
|
||||||
|
target.ConnPool = fresh.ConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Statement != nil {
|
||||||
|
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
|
||||||
|
target.Statement.ConnPool = fresh.Statement.ConnPool
|
||||||
|
} else if fresh.Config != nil {
|
||||||
|
target.Statement.ConnPool = fresh.ConnPool
|
||||||
|
}
|
||||||
|
target.Statement.DB = target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||||
// This is useful for debugging preload queries that may be failing
|
// This is useful for debugging preload queries that may be failing
|
||||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||||
|
g.dbMu.Lock()
|
||||||
g.db = g.db.Debug()
|
g.db = g.db.Debug()
|
||||||
|
g.dbMu.Unlock()
|
||||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
@@ -44,19 +109,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||||
return &GormInsertQuery{db: g.db}
|
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &GormUpdateQuery{db: g.db}
|
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &GormDeleteQuery{db: g.db}
|
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
@@ -65,7 +130,15 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
|||||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
run := func() *gorm.DB {
|
||||||
|
return g.getDB().WithContext(ctx).Exec(query, args...)
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,15 +148,32 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
|
|||||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
run := func() error {
|
||||||
|
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
tx := g.db.WithContext(ctx).Begin()
|
run := func() *gorm.DB {
|
||||||
|
return g.getDB().WithContext(ctx).Begin()
|
||||||
|
}
|
||||||
|
tx := run()
|
||||||
|
if isDBClosed(tx.Error) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
tx = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return nil, tx.Error
|
return nil, tx.Error
|
||||||
}
|
}
|
||||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||||
@@ -100,24 +190,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
run := func() error {
|
||||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
return fn(adapter)
|
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}
|
||||||
})
|
return fn(adapter)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||||
return g.db
|
return g.getDB()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) DriverName() string {
|
func (g *GormAdapter) DriverName() string {
|
||||||
if g.db.Dialector == nil {
|
return normalizeGormDriverName(g.getDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGormDriverName(db *gorm.DB) string {
|
||||||
|
if db == nil || db.Dialector == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||||
switch name := g.db.Name(); name {
|
switch name := db.Name(); name {
|
||||||
case "sqlserver":
|
case "sqlserver":
|
||||||
return "mssql"
|
return "mssql"
|
||||||
case "sqlite3":
|
case "sqlite3":
|
||||||
@@ -130,6 +233,7 @@ func (g *GormAdapter) DriverName() string {
|
|||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
reconnect func(...*gorm.DB) error
|
||||||
schema string // Separated schema name
|
schema string // Separated schema name
|
||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
@@ -347,6 +451,7 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
|||||||
|
|
||||||
wrapper := &GormSelectQuery{
|
wrapper := &GormSelectQuery{
|
||||||
db: db,
|
db: db,
|
||||||
|
reconnect: g.reconnect,
|
||||||
driverName: g.driverName,
|
driverName: g.driverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,6 +490,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
|||||||
|
|
||||||
wrapper := &GormSelectQuery{
|
wrapper := &GormSelectQuery{
|
||||||
db: db,
|
db: db,
|
||||||
|
reconnect: g.reconnect,
|
||||||
driverName: g.driverName,
|
driverName: g.driverName,
|
||||||
inJoinContext: true, // Mark as JOIN context
|
inJoinContext: true, // Mark as JOIN context
|
||||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||||
@@ -444,7 +550,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
|||||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
err = g.db.WithContext(ctx).Find(dest).Error
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Find(dest).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
@@ -464,7 +578,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
if g.db.Statement.Model == nil {
|
if g.db.Statement.Model == nil {
|
||||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||||
}
|
}
|
||||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
@@ -483,7 +605,15 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var count64 int64
|
var count64 int64
|
||||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Count(&count64).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
@@ -502,7 +632,15 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var count int64
|
var count int64
|
||||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
@@ -515,9 +653,10 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
|||||||
|
|
||||||
// GormInsertQuery implements InsertQuery for GORM
|
// GormInsertQuery implements InsertQuery for GORM
|
||||||
type GormInsertQuery struct {
|
type GormInsertQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
model interface{}
|
reconnect func(...*gorm.DB) error
|
||||||
values map[string]interface{}
|
model interface{}
|
||||||
|
values map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||||
@@ -555,23 +694,31 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var result *gorm.DB
|
run := func() *gorm.DB {
|
||||||
switch {
|
switch {
|
||||||
case g.model != nil:
|
case g.model != nil:
|
||||||
result = g.db.WithContext(ctx).Create(g.model)
|
return g.db.WithContext(ctx).Create(g.model)
|
||||||
case g.values != nil:
|
case g.values != nil:
|
||||||
result = g.db.WithContext(ctx).Create(g.values)
|
return g.db.WithContext(ctx).Create(g.values)
|
||||||
default:
|
default:
|
||||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
return g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GormUpdateQuery implements UpdateQuery for GORM
|
// GormUpdateQuery implements UpdateQuery for GORM
|
||||||
type GormUpdateQuery struct {
|
type GormUpdateQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
model interface{}
|
reconnect func(...*gorm.DB) error
|
||||||
updates interface{}
|
model interface{}
|
||||||
|
updates interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||||
@@ -647,7 +794,15 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
run := func() *gorm.DB {
|
||||||
|
return g.db.WithContext(ctx).Updates(g.updates)
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
@@ -660,8 +815,9 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
|
|
||||||
// GormDeleteQuery implements DeleteQuery for GORM
|
// GormDeleteQuery implements DeleteQuery for GORM
|
||||||
type GormDeleteQuery struct {
|
type GormDeleteQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
model interface{}
|
reconnect func(...*gorm.DB) error
|
||||||
|
model interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||||
@@ -686,7 +842,15 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Delete(g.model)
|
run := func() *gorm.DB {
|
||||||
|
return g.db.WithContext(ctx).Delete(g.model)
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
|||||||
@@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
|
|||||||
return stats
|
return stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reconnectForAdapter() error {
|
||||||
|
timeout := c.config.ConnectTimeout
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return c.Reconnect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
|
||||||
|
if err := c.reconnectForAdapter(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Native()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
|
||||||
|
if err := c.reconnectForAdapter(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Bun()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
|
||||||
|
if err := c.reconnectForAdapter(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.GORM()
|
||||||
|
}
|
||||||
|
|
||||||
// getBunAdapter returns or creates the Bun adapter
|
// getBunAdapter returns or creates the Bun adapter
|
||||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -391,7 +427,7 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
|||||||
c.bunDB = bun.NewDB(native, dialect)
|
c.bunDB = bun.NewDB(native, dialect)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
c.bunAdapter = database.NewBunAdapter(c.bunDB).WithDBFactory(c.reopenBunForAdapter)
|
||||||
return c.bunAdapter, nil
|
return c.bunAdapter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -432,7 +468,7 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
|||||||
c.gormDB = db
|
c.gormDB = db
|
||||||
}
|
}
|
||||||
|
|
||||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
c.gormAdapter = database.NewGormAdapter(c.gormDB).WithDBFactory(c.reopenGORMForAdapter)
|
||||||
return c.gormAdapter, nil
|
return c.gormAdapter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -473,11 +509,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
|||||||
// Create a native adapter based on database type
|
// Create a native adapter based on database type
|
||||||
switch c.dbType {
|
switch c.dbType {
|
||||||
case DatabaseTypePostgreSQL:
|
case DatabaseTypePostgreSQL:
|
||||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||||
case DatabaseTypeSQLite:
|
case DatabaseTypeSQLite:
|
||||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||||
case DatabaseTypeMSSQL:
|
case DatabaseTypeMSSQL:
|
||||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
|
||||||
default:
|
default:
|
||||||
return nil, ErrUnsupportedDatabase
|
return nil, ErrUnsupportedDatabase
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewConnectionFromDB(t *testing.T) {
|
func TestNewConnectionFromDB(t *testing.T) {
|
||||||
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
|||||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
|
||||||
|
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
|
||||||
|
Name: "test-native",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
DefaultORM: string(ORMTypeNative),
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}, providers.NewSQLiteProvider())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := conn.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get database adapter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter, ok := db.(*database.PgSQLAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected PgSQLAdapter, got %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := underlyingBefore.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close underlying database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||||
|
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
if underlyingAfter == underlyingBefore {
|
||||||
|
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
|
||||||
|
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
|
||||||
|
Name: "test-bun",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
DefaultORM: string(ORMTypeBun),
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}, providers.NewSQLiteProvider())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := conn.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get database adapter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter, ok := db.(*database.BunAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected BunAdapter, got %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := underlyingBefore.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close underlying Bun database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||||
|
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingAfter := adapter.GetUnderlyingDB()
|
||||||
|
if underlyingAfter == underlyingBefore {
|
||||||
|
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
|
||||||
|
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
|
||||||
|
Name: "test-gorm",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
DefaultORM: string(ORMTypeGORM),
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}, providers.NewSQLiteProvider())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := conn.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get database adapter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter, ok := db.(*database.GormAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected GormAdapter, got %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlBefore, err := gormBefore.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sqlBefore.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close underlying database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
|
||||||
|
}
|
||||||
|
if count < 0 {
|
||||||
|
t.Fatalf("Expected non-negative count, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlAfter, err := gormAfter.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlAfter == sqlBefore {
|
||||||
|
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package dbmanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
|
|||||||
"connection", item.name,
|
"connection", item.name,
|
||||||
"error", err)
|
"error", err)
|
||||||
|
|
||||||
// Attempt reconnection if enabled
|
// Only reconnect when the client handle itself is closed/disconnected.
|
||||||
if m.config.EnableAutoReconnect {
|
// For transient database restarts or network blips, *sql.DB can recover
|
||||||
|
// on its own; forcing Close()+Connect() here invalidates any cached ORM
|
||||||
|
// wrappers and callers that still hold the old handle.
|
||||||
|
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
|
||||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||||
if err := item.conn.Reconnect(ctx); err != nil {
|
if err := item.conn.Reconnect(ctx); err != nil {
|
||||||
logger.Error("Reconnection failed",
|
logger.Error("Reconnection failed",
|
||||||
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
|
|||||||
} else {
|
} else {
|
||||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||||
}
|
}
|
||||||
|
} else if m.config.EnableAutoReconnect {
|
||||||
|
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldReconnectAfterHealthCheck(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, ErrConnectionClosed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(err.Error(), "sql: database is closed")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,12 +3,38 @@ package dbmanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type healthCheckStubConnection struct {
|
||||||
|
healthErr error
|
||||||
|
reconnectCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *healthCheckStubConnection) Name() string { return "stub" }
|
||||||
|
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
|
||||||
|
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
|
||||||
|
func (c *healthCheckStubConnection) Close() error { return nil }
|
||||||
|
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
|
||||||
|
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
|
||||||
|
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
|
||||||
|
|
||||||
func TestBackgroundHealthChecker(t *testing.T) {
|
func TestBackgroundHealthChecker(t *testing.T) {
|
||||||
// Create a SQLite in-memory database
|
// Create a SQLite in-memory database
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
|
|||||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
|
||||||
|
conn := &healthCheckStubConnection{
|
||||||
|
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &connectionManager{
|
||||||
|
connections: map[string]Connection{"primary": conn},
|
||||||
|
config: ManagerConfig{
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.performHealthCheck()
|
||||||
|
|
||||||
|
if conn.reconnectCalls != 0 {
|
||||||
|
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
|
||||||
|
conn := &healthCheckStubConnection{
|
||||||
|
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &connectionManager{
|
||||||
|
connections: map[string]Connection{"primary": conn},
|
||||||
|
config: ManagerConfig{
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.performHealthCheck()
|
||||||
|
|
||||||
|
if conn.reconnectCalls != 1 {
|
||||||
|
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
|
||||||
|
db := a.getDB()
|
||||||
|
if db == nil {
|
||||||
|
return fmt.Errorf("database connection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := run(db)
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run(a.getDB())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||||
// Convert LoginRequest to JSON
|
// Convert LoginRequest to JSON
|
||||||
reqJSON, err := json.Marshal(req)
|
reqJSON, err := json.Marshal(req)
|
||||||
@@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
runLoginQuery := func() error {
|
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
}
|
})
|
||||||
err = runLoginQuery()
|
|
||||||
if isDBClosed(err) {
|
|
||||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
|
||||||
err = runLoginQuery()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login query failed: %w", err)
|
return nil, fmt.Errorf("login query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||||
|
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("register query failed: %w", err)
|
return nil, fmt.Errorf("register query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||||
|
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("logout query failed: %w", err)
|
return fmt.Errorf("logout query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var userJSON sql.NullString
|
var userJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||||
|
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("session query failed: %w", err)
|
return nil, fmt.Errorf("session query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var updatedUserJSON sql.NullString
|
var updatedUserJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||||
|
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken implements Refreshable interface
|
// RefreshToken implements Refreshable interface
|
||||||
@@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var userJSON sql.NullString
|
var userJSON sql.NullString
|
||||||
// Get current session to pass to refresh
|
// Get current session to pass to refresh
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||||
|
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
var newErrorMsg sql.NullString
|
var newErrorMsg sql.NullString
|
||||||
var newUserJSON sql.NullString
|
var newUserJSON sql.NullString
|
||||||
|
|
||||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||||
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||||
|
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package security
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
|
||||||
|
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
primaryDB, primaryMock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create primary mock db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reconnectDB, reconnectMock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
primaryDB.Close()
|
||||||
|
t.Fatalf("failed to create reconnect mock db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheProvider := cache.NewMemoryProvider(&cache.Options{
|
||||||
|
DefaultTTL: 1 * time.Minute,
|
||||||
|
MaxSize: 1000,
|
||||||
|
})
|
||||||
|
|
||||||
|
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
|
||||||
|
Cache: cache.NewCache(cacheProvider),
|
||||||
|
DBFactory: func() (*sql.DB, error) {
|
||||||
|
return reconnectDB, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
_ = primaryDB.Close()
|
||||||
|
_ = reconnectDB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth, primaryMock, reconnectMock, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
|
||||||
|
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
|
||||||
|
|
||||||
|
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("reconnect-auth-token", "authenticate").
|
||||||
|
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||||
|
|
||||||
|
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
|
||||||
|
|
||||||
|
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("reconnect-auth-token", "authenticate").
|
||||||
|
WillReturnRows(reconnectRows)
|
||||||
|
|
||||||
|
userCtx, err := auth.Authenticate(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected authenticate to reconnect, got %v", err)
|
||||||
|
}
|
||||||
|
if userCtx.UserID != 7 {
|
||||||
|
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("primary db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Register reconnects after closed database", func(t *testing.T) {
|
||||||
|
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
req := RegisterRequest{
|
||||||
|
Username: "reconnect-register",
|
||||||
|
Password: "password123",
|
||||||
|
Email: "reconnect@example.com",
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||||
|
|
||||||
|
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||||
|
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
|
||||||
|
|
||||||
|
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(reconnectRows)
|
||||||
|
|
||||||
|
resp, err := auth.Register(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected register to reconnect, got %v", err)
|
||||||
|
}
|
||||||
|
if resp.Token != "reconnected-register-token" {
|
||||||
|
t.Fatalf("expected refreshed token, got %s", resp.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("primary db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Logout reconnects after closed database", func(t *testing.T) {
|
||||||
|
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
|
||||||
|
|
||||||
|
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||||
|
|
||||||
|
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||||
|
AddRow(true, nil, nil)
|
||||||
|
|
||||||
|
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(reconnectRows)
|
||||||
|
|
||||||
|
if err := auth.Logout(context.Background(), req); err != nil {
|
||||||
|
t.Fatalf("expected logout to reconnect, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("primary db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
|
||||||
|
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
refreshToken := "refresh-reconnect-token"
|
||||||
|
|
||||||
|
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs(refreshToken, "refresh").
|
||||||
|
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||||
|
|
||||||
|
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
|
||||||
|
|
||||||
|
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs(refreshToken, "refresh").
|
||||||
|
WillReturnRows(sessionRows)
|
||||||
|
|
||||||
|
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
|
||||||
|
|
||||||
|
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
|
||||||
|
WithArgs(refreshToken, sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(refreshRows)
|
||||||
|
|
||||||
|
resp, err := auth.RefreshToken(context.Background(), refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected refresh token to reconnect, got %v", err)
|
||||||
|
}
|
||||||
|
if resp.Token != "refreshed-token" {
|
||||||
|
t.Fatalf("expected refreshed-token, got %s", resp.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("primary db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
|
||||||
|
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
|
||||||
|
|
||||||
|
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||||
|
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||||
|
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||||
|
|
||||||
|
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
|
||||||
|
|
||||||
|
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||||
|
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(reconnectRows)
|
||||||
|
|
||||||
|
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
|
||||||
|
|
||||||
|
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("primary db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Test JWTAuthenticator
|
// Test JWTAuthenticator
|
||||||
func TestJWTAuthenticator(t *testing.T) {
|
func TestJWTAuthenticator(t *testing.T) {
|
||||||
db, mock, err := sqlmock.New()
|
db, mock, err := sqlmock.New()
|
||||||
|
|||||||
Reference in New Issue
Block a user