mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-10 18:03:57 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfb63c3328 | ||
|
|
e8d0ab28c3 | ||
|
|
4fc25c60ae | ||
|
|
16a960d973 | ||
|
|
2afee9d238 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ test.db
|
||||
tests/data/
|
||||
node_modules/
|
||||
resolvespec-js/dist/
|
||||
.codex
|
||||
|
||||
@@ -95,15 +95,16 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
db *bun.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*bun.DB, error)
|
||||
driverName string
|
||||
db *bun.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*bun.DB, error)
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
// NewBunAdapter creates a new Bun adapter
|
||||
func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
adapter := &BunAdapter{db: db}
|
||||
adapter := &BunAdapter{db: db, metricsEnabled: true}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
@@ -115,6 +116,12 @@ func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter
|
||||
return b
|
||||
}
|
||||
|
||||
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||
func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter {
|
||||
b.metricsEnabled = enabled
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunAdapter) getDB() *bun.DB {
|
||||
b.dbMu.RLock()
|
||||
defer b.dbMu.RUnlock()
|
||||
@@ -159,22 +166,23 @@ func (b *BunAdapter) DisableQueryDebug() {
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.getDB().NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
query: b.getDB().NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
metricsEnabled: b.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||
return &BunInsertQuery{query: b.getDB().NewInsert()}
|
||||
return &BunInsertQuery{query: b.getDB().NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
|
||||
return &BunUpdateQuery{query: b.getDB().NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.getDB().NewDelete()}
|
||||
return &BunDeleteQuery{query: b.getDB().NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -183,6 +191,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
@@ -191,6 +201,7 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -200,22 +211,29 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
tx, err = b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// For Bun, we'll return a special wrapper that holds the transaction
|
||||
return &BunTxAdapter{tx: tx, driverName: b.driverName}, nil
|
||||
return &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}, nil
|
||||
}
|
||||
|
||||
func (b *BunAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -236,11 +254,19 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
run := func() error {
|
||||
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||
@@ -268,25 +294,24 @@ type BunSelectQuery struct {
|
||||
hasModel bool // Track if Model() was called
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
entity string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
skipAutoDetect bool // Skip auto-detection to prevent circular calls
|
||||
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true // Mark that we have a model
|
||||
|
||||
// Try to get table name from model if it implements TableNameProvider
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
b.schema, b.tableName = parseTableName(fullTableName, b.driverName)
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
b.tableAlias = provider.TableAlias()
|
||||
@@ -300,6 +325,9 @@ func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -605,9 +633,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
|
||||
// Wrap the incoming *bun.SelectQuery in our adapter
|
||||
wrapper := &BunSelectQuery{
|
||||
query: sq,
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
query: sq,
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
metricsEnabled: b.metricsEnabled,
|
||||
}
|
||||
|
||||
// Try to extract table name and alias from the preload model
|
||||
@@ -858,7 +887,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
||||
|
||||
// Apply user's functions (if any)
|
||||
if isLast && len(applyFuncs) > 0 {
|
||||
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName}
|
||||
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
for _, fn := range applyFuncs {
|
||||
if fn != nil {
|
||||
wrapper = fn(wrapper).(*BunSelectQuery)
|
||||
@@ -1210,27 +1239,28 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
if dest == nil {
|
||||
return fmt.Errorf("destination cannot be nil")
|
||||
err = fmt.Errorf("destination cannot be nil")
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.query.Scan(ctx, dest)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Enhanced panic recovery with model information
|
||||
@@ -1240,7 +1270,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
modelValue := model.Value()
|
||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||
|
||||
// Try to get the model's underlying struct type
|
||||
v := reflect.ValueOf(modelValue)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
@@ -1260,9 +1289,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
if b.query.GetModel() == nil {
|
||||
return fmt.Errorf("model is nil")
|
||||
err = fmt.Errorf("model is nil")
|
||||
return err
|
||||
}
|
||||
|
||||
// Optional: Enable detailed field-level debugging (set to true to debug)
|
||||
@@ -1278,7 +1309,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
@@ -1287,7 +1317,7 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
// After main query, load custom preloads using separate queries
|
||||
if len(b.customPreloads) > 0 {
|
||||
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
|
||||
if err := b.loadCustomPreloads(ctx); err != nil {
|
||||
if err = b.loadCustomPreloads(ctx); err != nil {
|
||||
logger.Error("Failed to load custom preloads: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -1297,21 +1327,22 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
count, err := b.query.Count(ctx)
|
||||
count, err = b.query.Count(ctx) // assign to named returns, not shadow vars
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||
@@ -1321,39 +1352,49 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
ColumnExpr("COUNT(*)")
|
||||
err = countQuery.Scan(ctx, &count)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := countQuery.String()
|
||||
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
return
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
exists, err = b.query.Exists(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return exists, err
|
||||
return
|
||||
}
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
type BunInsertQuery struct {
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
driverName string
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1362,6 +1403,10 @@ func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Table(table)
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1391,6 +1436,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
if len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
@@ -1404,29 +1450,45 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}
|
||||
result, err := b.query.Exec(ctx)
|
||||
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
// BunUpdateQuery implements UpdateQuery for Bun
|
||||
type BunUpdateQuery struct {
|
||||
query *bun.UpdateQuery
|
||||
model interface{}
|
||||
query *bun.UpdateQuery
|
||||
model interface{}
|
||||
driverName string
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.model = model
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
b.query = b.query.Table(table)
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
if b.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
b.model = model
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
}
|
||||
}
|
||||
return b
|
||||
@@ -1477,27 +1539,43 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "UPDATE", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
// BunDeleteQuery implements DeleteQuery for Bun
|
||||
type BunDeleteQuery struct {
|
||||
query *bun.DeleteQuery
|
||||
query *bun.DeleteQuery
|
||||
driverName string
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
b.query = b.query.Table(table)
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1512,12 +1590,14 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "DELETE", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -1543,37 +1623,46 @@ func (b *BunResult) LastInsertId() (int64, error) {
|
||||
|
||||
// BunTxAdapter wraps a Bun transaction to implement the Database interface
|
||||
type BunTxAdapter struct {
|
||||
tx bun.Tx
|
||||
driverName string
|
||||
tx bun.Tx
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.tx.NewSelect(),
|
||||
db: b.tx,
|
||||
driverName: b.driverName,
|
||||
query: b.tx.NewSelect(),
|
||||
db: b.tx,
|
||||
driverName: b.driverName,
|
||||
metricsEnabled: b.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
||||
return &BunInsertQuery{query: b.tx.NewInsert()}
|
||||
return &BunInsertQuery{query: b.tx.NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &BunUpdateQuery{query: b.tx.NewUpdate()}
|
||||
return &BunUpdateQuery{query: b.tx.NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.tx.NewDelete()}
|
||||
return &BunDeleteQuery{query: b.tx.NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
result, err := b.tx.ExecContext(ctx, query, args...)
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return b.tx.NewRaw(query, args...).Scan(ctx, dest)
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
err := b.tx.NewRaw(query, args...).Scan(ctx, dest)
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -15,22 +17,93 @@ import (
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
type GormAdapter struct {
|
||||
db *gorm.DB
|
||||
driverName string
|
||||
dbMu sync.RWMutex
|
||||
db *gorm.DB
|
||||
dbFactory func() (*gorm.DB, error)
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
// NewGormAdapter creates a new GORM adapter
|
||||
func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
adapter := &GormAdapter{db: db}
|
||||
adapter := &GormAdapter{db: db, metricsEnabled: true}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
|
||||
g.dbFactory = factory
|
||||
return g
|
||||
}
|
||||
|
||||
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||
func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter {
|
||||
g.metricsEnabled = enabled
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormAdapter) getDB() *gorm.DB {
|
||||
g.dbMu.RLock()
|
||||
defer g.dbMu.RUnlock()
|
||||
return g.db
|
||||
}
|
||||
|
||||
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
|
||||
if g.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
|
||||
freshDB, err := g.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.dbMu.Lock()
|
||||
previous := g.db
|
||||
g.db = freshDB
|
||||
g.driverName = normalizeGormDriverName(freshDB)
|
||||
g.dbMu.Unlock()
|
||||
|
||||
if previous != nil {
|
||||
syncGormConnPool(previous, freshDB)
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if target != nil && target != previous {
|
||||
syncGormConnPool(target, freshDB)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncGormConnPool(target, fresh *gorm.DB) {
|
||||
if target == nil || fresh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if target.Config != nil && fresh.Config != nil {
|
||||
target.ConnPool = fresh.ConnPool
|
||||
}
|
||||
|
||||
if target.Statement != nil {
|
||||
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
|
||||
target.Statement.ConnPool = fresh.Statement.ConnPool
|
||||
} else if fresh.Config != nil {
|
||||
target.Statement.ConnPool = fresh.ConnPool
|
||||
}
|
||||
target.Statement.DB = target
|
||||
}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||
g.dbMu.Lock()
|
||||
g.db = g.db.Debug()
|
||||
g.dbMu.Unlock()
|
||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||
return g
|
||||
}
|
||||
@@ -44,19 +117,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
||||
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||
return &GormInsertQuery{db: g.db}
|
||||
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &GormUpdateQuery{db: g.db}
|
||||
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
return &GormDeleteQuery{db: g.db}
|
||||
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -65,7 +138,18 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Exec(query, args...)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -75,15 +159,35 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx := g.db.WithContext(ctx).Begin()
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Begin()
|
||||
}
|
||||
tx := run()
|
||||
if isDBClosed(tx.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
tx = run()
|
||||
}
|
||||
}
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
||||
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}, nil
|
||||
}
|
||||
|
||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -100,24 +204,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
return g.getDB()
|
||||
}
|
||||
|
||||
func (g *GormAdapter) DriverName() string {
|
||||
if g.db.Dialector == nil {
|
||||
return normalizeGormDriverName(g.getDB())
|
||||
}
|
||||
|
||||
func normalizeGormDriverName(db *gorm.DB) string {
|
||||
if db == nil || db.Dialector == nil {
|
||||
return ""
|
||||
}
|
||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||
switch name := g.db.Name(); name {
|
||||
switch name := db.Name(); name {
|
||||
case "sqlserver":
|
||||
return "mssql"
|
||||
case "sqlite3":
|
||||
@@ -130,24 +247,21 @@ func (g *GormAdapter) DriverName() string {
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
entity string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
g.db = g.db.Model(model)
|
||||
|
||||
// Try to get table name from model if it implements TableNameProvider
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
|
||||
}
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
g.tableAlias = provider.TableAlias()
|
||||
@@ -161,6 +275,9 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
|
||||
return g
|
||||
}
|
||||
@@ -346,8 +463,10 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
driverName: g.driverName,
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
metricsEnabled: g.metricsEnabled,
|
||||
}
|
||||
|
||||
current := common.SelectQuery(wrapper)
|
||||
@@ -385,9 +504,11 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
metricsEnabled: g.metricsEnabled,
|
||||
}
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
@@ -444,7 +565,16 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
err = g.db.WithContext(ctx).Find(dest).Error
|
||||
startedAt := time.Now()
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -452,6 +582,7 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
})
|
||||
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -464,7 +595,16 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
startedAt := time.Now()
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -472,6 +612,7 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
})
|
||||
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -482,8 +623,17 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Count(&count64).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -491,6 +641,7 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
})
|
||||
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "COUNT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return int(count64), err
|
||||
}
|
||||
|
||||
@@ -501,8 +652,17 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
var count int64
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -510,24 +670,37 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
})
|
||||
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "EXISTS", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GormInsertQuery implements InsertQuery for GORM
|
||||
type GormInsertQuery struct {
|
||||
db *gorm.DB
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
g.model = model
|
||||
g.db = g.db.Model(model)
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -555,38 +728,60 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
switch {
|
||||
case g.model != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
startedAt := time.Now()
|
||||
run := func() *gorm.DB {
|
||||
switch {
|
||||
case g.model != nil:
|
||||
return g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
return g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
return g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
}
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
// GormUpdateQuery implements UpdateQuery for GORM
|
||||
type GormUpdateQuery struct {
|
||||
db *gorm.DB
|
||||
model interface{}
|
||||
updates interface{}
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
updates interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
g.model = model
|
||||
g.db = g.db.Model(model)
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
if g.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
g.model = model
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
}
|
||||
}
|
||||
return g
|
||||
@@ -647,7 +842,16 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
startedAt := time.Now()
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Updates(g.updates)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -655,23 +859,36 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
})
|
||||
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "UPDATE", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
// GormDeleteQuery implements DeleteQuery for GORM
|
||||
type GormDeleteQuery struct {
|
||||
db *gorm.DB
|
||||
model interface{}
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
g.model = model
|
||||
g.db = g.db.Model(model)
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -686,7 +903,16 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
startedAt := time.Now()
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Delete(g.model)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -694,6 +920,7 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
})
|
||||
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "DELETE", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -17,10 +19,11 @@ import (
|
||||
// PgSQLAdapter adapts standard database/sql to work with our Database interface
|
||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||
type PgSQLAdapter struct {
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
driverName string
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
|
||||
@@ -31,7 +34,7 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
||||
if len(driverName) > 0 && driverName[0] != "" {
|
||||
name = driverName[0]
|
||||
}
|
||||
return &PgSQLAdapter{db: db, driverName: name}
|
||||
return &PgSQLAdapter{db: db, driverName: name, metricsEnabled: true}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
@@ -40,6 +43,12 @@ func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdap
|
||||
return p
|
||||
}
|
||||
|
||||
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||
func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter {
|
||||
p.metricsEnabled = enabled
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
@@ -71,37 +80,41 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
|
||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +124,8 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
||||
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
|
||||
@@ -122,8 +137,10 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Exec failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return nil, err
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
|
||||
return &PgSQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
@@ -133,6 +150,8 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
err = logger.HandlePanic("PgSQLAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
||||
var rows *sql.Rows
|
||||
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
|
||||
@@ -144,11 +163,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Query failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows, dest)
|
||||
err = scanRows(rows, dest)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
@@ -156,7 +178,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -179,7 +201,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return err
|
||||
}
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
@@ -222,34 +244,34 @@ type relationMetadata struct {
|
||||
|
||||
// PgSQLSelectQuery implements SelectQuery for PostgreSQL
|
||||
type PgSQLSelectQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
model interface{}
|
||||
tableName string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
columns []string
|
||||
columnExprs []string
|
||||
whereClauses []string
|
||||
orClauses []string
|
||||
joins []string
|
||||
orderBy []string
|
||||
groupBy []string
|
||||
havingClauses []string
|
||||
limit int
|
||||
offset int
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
preloads []preloadConfig
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
model interface{}
|
||||
entity string
|
||||
tableName string
|
||||
schema string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
columns []string
|
||||
columnExprs []string
|
||||
whereClauses []string
|
||||
orClauses []string
|
||||
joins []string
|
||||
orderBy []string
|
||||
groupBy []string
|
||||
havingClauses []string
|
||||
limit int
|
||||
offset int
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
preloads []preloadConfig
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
p.tableAlias = provider.TableAlias()
|
||||
}
|
||||
@@ -258,7 +280,10 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
|
||||
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -468,6 +493,7 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
|
||||
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
|
||||
// Apply preloads that use JOINs
|
||||
p.applyJoinPreloads()
|
||||
@@ -484,17 +510,21 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
|
||||
|
||||
if err != nil {
|
||||
logger.Error("PgSQL SELECT failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
err = scanRows(rows, dest)
|
||||
if err != nil {
|
||||
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply preloads that use separate queries
|
||||
return p.applySubqueryPreloads(ctx, dest)
|
||||
err = p.applySubqueryPreloads(ctx, dest)
|
||||
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
||||
@@ -504,15 +534,8 @@ func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
||||
return p.Scan(ctx, p.model)
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
|
||||
// Build a COUNT query
|
||||
// countInternal executes the COUNT query and returns the result without recording metrics.
|
||||
func (p *PgSQLSelectQuery) countInternal(ctx context.Context) (int, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("SELECT COUNT(*) FROM ")
|
||||
sb.WriteString(p.tableName)
|
||||
@@ -546,10 +569,26 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
row = p.db.QueryRowContext(ctx, query, p.args...)
|
||||
}
|
||||
|
||||
err = row.Scan(&count)
|
||||
var count int
|
||||
if err := row.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
count, err = p.countInternal(ctx)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL COUNT failed: %v", err)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "COUNT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -560,27 +599,32 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
|
||||
count, err := p.Count(ctx)
|
||||
startedAt := time.Now()
|
||||
count, err := p.countInternal(ctx)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL EXISTS failed: %v", err)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "EXISTS", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
|
||||
type PgSQLInsertQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
values map[string]interface{}
|
||||
returning []string
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
values map[string]interface{}
|
||||
valueOrder []string
|
||||
returning []string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
// Extract values from model using reflection
|
||||
// This is a simplified implementation
|
||||
return p
|
||||
@@ -588,11 +632,17 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
|
||||
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
|
||||
if _, exists := p.values[column]; !exists {
|
||||
p.valueOrder = append(p.valueOrder, column)
|
||||
}
|
||||
p.values[column] = value
|
||||
return p
|
||||
}
|
||||
@@ -608,25 +658,27 @@ func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
if len(p.values) == 0 {
|
||||
return nil, fmt.Errorf("no values to insert")
|
||||
err = fmt.Errorf("no values to insert")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(p.values))
|
||||
placeholders := make([]string, 0, len(p.values))
|
||||
args := make([]interface{}, 0, len(p.values))
|
||||
|
||||
i := 1
|
||||
for col, val := range p.values {
|
||||
for _, col := range p.valueOrder {
|
||||
columns = append(columns, col)
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
||||
args = append(args, val)
|
||||
args = append(args, p.values[col])
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -658,35 +710,40 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
|
||||
|
||||
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
|
||||
type PgSQLUpdateQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
model interface{}
|
||||
sets map[string]interface{}
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
returning []string
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
model interface{}
|
||||
sets map[string]interface{}
|
||||
setOrder []string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
returning []string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
if p.model == nil {
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
p.model = model
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
}
|
||||
}
|
||||
return p
|
||||
@@ -696,6 +753,9 @@ func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQu
|
||||
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
||||
return p
|
||||
}
|
||||
if _, exists := p.sets[column]; !exists {
|
||||
p.setOrder = append(p.setOrder, column)
|
||||
}
|
||||
p.sets[column] = value
|
||||
return p
|
||||
}
|
||||
@@ -706,13 +766,23 @@ func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQu
|
||||
pkName = reflection.GetPrimaryKeyName(p.model)
|
||||
}
|
||||
|
||||
for column, value := range values {
|
||||
orderedColumns := make([]string, 0, len(values))
|
||||
for column := range values {
|
||||
orderedColumns = append(orderedColumns, column)
|
||||
}
|
||||
sort.Strings(orderedColumns)
|
||||
|
||||
for _, column := range orderedColumns {
|
||||
value := values[column]
|
||||
if pkName != "" && column == pkName {
|
||||
continue
|
||||
}
|
||||
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
||||
continue
|
||||
}
|
||||
if _, exists := p.sets[column]; !exists {
|
||||
p.setOrder = append(p.setOrder, column)
|
||||
}
|
||||
p.sets[column] = value
|
||||
}
|
||||
return p
|
||||
@@ -741,24 +811,26 @@ func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) strin
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
if len(p.sets) == 0 {
|
||||
return nil, fmt.Errorf("no values to update")
|
||||
err = fmt.Errorf("no values to update")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setClauses := make([]string, 0, len(p.sets))
|
||||
setArgs := make([]interface{}, 0, len(p.sets))
|
||||
|
||||
// SET parameters start at $1
|
||||
i := 1
|
||||
for col, val := range p.sets {
|
||||
for _, col := range p.setOrder {
|
||||
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
|
||||
setArgs = append(setArgs, val)
|
||||
setArgs = append(setArgs, p.sets[col])
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -820,27 +892,30 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
|
||||
|
||||
// PgSQLDeleteQuery implements DeleteQuery for PostgreSQL
|
||||
type PgSQLDeleteQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -862,10 +937,12 @@ func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) strin
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
|
||||
@@ -913,66 +990,80 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
|
||||
|
||||
// PgSQLTxAdapter wraps a PostgreSQL transaction
|
||||
type PgSQLTxAdapter struct {
|
||||
tx *sql.Tx
|
||||
driverName string
|
||||
tx *sql.Tx
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
|
||||
result, err := p.tx.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Tx Exec failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return nil, err
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
|
||||
return &PgSQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
|
||||
rows, err := p.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Tx Query failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows, dest)
|
||||
err = scanRows(rows, dest)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
|
||||
206
pkg/common/adapters/database/query_metrics.go
Normal file
206
pkg/common/adapters/database/query_metrics.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) {
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
|
||||
metrics.GetProvider().RecordDBQuery(
|
||||
normalizeMetricOperation(operation),
|
||||
normalizeMetricSchema(schema),
|
||||
normalizeMetricEntity(entity, table),
|
||||
normalizeMetricTable(table),
|
||||
time.Since(startedAt),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
func normalizeMetricOperation(operation string) string {
|
||||
operation = strings.ToUpper(strings.TrimSpace(operation))
|
||||
if operation == "" {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return operation
|
||||
}
|
||||
|
||||
func normalizeMetricSchema(schema string) string {
|
||||
schema = cleanMetricIdentifier(schema)
|
||||
if schema == "" {
|
||||
return "default"
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func normalizeMetricEntity(entity, table string) string {
|
||||
entity = cleanMetricIdentifier(entity)
|
||||
if entity != "" {
|
||||
return entity
|
||||
}
|
||||
|
||||
table = cleanMetricIdentifier(table)
|
||||
if table != "" {
|
||||
return table
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func normalizeMetricTable(table string) string {
|
||||
table = cleanMetricIdentifier(table)
|
||||
if table == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func entityNameFromModel(model interface{}, table string) string {
|
||||
if model == nil {
|
||||
return cleanMetricIdentifier(table)
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil {
|
||||
return cleanMetricIdentifier(table)
|
||||
}
|
||||
|
||||
if modelType.Kind() == reflect.Struct && modelType.Name() != "" {
|
||||
return reflection.ToSnakeCase(modelType.Name())
|
||||
}
|
||||
|
||||
return cleanMetricIdentifier(table)
|
||||
}
|
||||
|
||||
func schemaAndTableFromModel(model interface{}, driverName string) (schema, table string) {
|
||||
provider, ok := tableNameProviderFromModel(model)
|
||||
if !ok {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
return parseTableName(provider.TableName(), driverName)
|
||||
}
|
||||
|
||||
// tableNameProviderType is cached to avoid repeated reflection on every call.
|
||||
var tableNameProviderType = reflect.TypeOf((*common.TableNameProvider)(nil)).Elem()
|
||||
|
||||
func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) {
|
||||
if model == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
return provider, true
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check whether *T implements TableNameProvider before allocating.
|
||||
ptrType := reflect.PointerTo(modelType)
|
||||
if !ptrType.Implements(tableNameProviderType) && !modelType.Implements(tableNameProviderType) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
modelValue := reflect.New(modelType)
|
||||
if provider, ok := modelValue.Interface().(common.TableNameProvider); ok {
|
||||
return provider, true
|
||||
}
|
||||
|
||||
if provider, ok := modelValue.Elem().Interface().(common.TableNameProvider); ok {
|
||||
return provider, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func metricTargetFromRawQuery(query, driverName string) (operation, schema, entity, table string) {
|
||||
operation = normalizeMetricOperation(firstQueryKeyword(query))
|
||||
tableRef := tableFromRawQuery(query, operation)
|
||||
if tableRef == "" {
|
||||
return operation, "", "unknown", "unknown"
|
||||
}
|
||||
|
||||
schema, table = parseTableName(tableRef, driverName)
|
||||
entity = cleanMetricIdentifier(table)
|
||||
return operation, schema, entity, table
|
||||
}
|
||||
|
||||
func firstQueryKeyword(query string) string {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
fields := strings.Fields(query)
|
||||
if len(fields) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fields[0]
|
||||
}
|
||||
|
||||
func tableFromRawQuery(query, operation string) string {
|
||||
tokens := tokenizeQuery(query)
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch operation {
|
||||
case "SELECT":
|
||||
return tokenAfter(tokens, "FROM")
|
||||
case "INSERT":
|
||||
return tokenAfter(tokens, "INTO")
|
||||
case "UPDATE":
|
||||
return tokenAfter(tokens, "UPDATE")
|
||||
case "DELETE":
|
||||
return tokenAfter(tokens, "FROM")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func tokenAfter(tokens []string, keyword string) string {
|
||||
for idx, token := range tokens {
|
||||
if strings.EqualFold(token, keyword) && idx+1 < len(tokens) {
|
||||
return cleanMetricIdentifier(tokens[idx+1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tokenizeQuery(query string) []string {
|
||||
replacer := strings.NewReplacer(
|
||||
"\n", " ",
|
||||
"\t", " ",
|
||||
"(", " ",
|
||||
")", " ",
|
||||
",", " ",
|
||||
)
|
||||
return strings.Fields(replacer.Replace(query))
|
||||
}
|
||||
|
||||
func cleanMetricIdentifier(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
value = strings.Trim(value, "\"'`[]")
|
||||
value = strings.TrimRight(value, ";")
|
||||
return value
|
||||
}
|
||||
306
pkg/common/adapters/database/query_metrics_test.go
Normal file
306
pkg/common/adapters/database/query_metrics_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"github.com/uptrace/bun/driver/sqliteshim"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
type queryMetricCall struct {
|
||||
operation string
|
||||
schema string
|
||||
entity string
|
||||
table string
|
||||
}
|
||||
|
||||
type capturingMetricsProvider struct {
|
||||
mu sync.Mutex
|
||||
calls []queryMetricCall
|
||||
}
|
||||
|
||||
func (c *capturingMetricsProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
|
||||
}
|
||||
func (c *capturingMetricsProvider) IncRequestsInFlight() {}
|
||||
func (c *capturingMetricsProvider) DecRequestsInFlight() {}
|
||||
func (c *capturingMetricsProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.calls = append(c.calls, queryMetricCall{
|
||||
operation: operation,
|
||||
schema: schema,
|
||||
entity: entity,
|
||||
table: table,
|
||||
})
|
||||
}
|
||||
func (c *capturingMetricsProvider) RecordCacheHit(provider string) {}
|
||||
func (c *capturingMetricsProvider) RecordCacheMiss(provider string) {}
|
||||
func (c *capturingMetricsProvider) UpdateCacheSize(provider string, size int64) {
|
||||
}
|
||||
func (c *capturingMetricsProvider) RecordEventPublished(source, eventType string) {}
|
||||
func (c *capturingMetricsProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||
}
|
||||
func (c *capturingMetricsProvider) UpdateEventQueueSize(size int64) {}
|
||||
func (c *capturingMetricsProvider) RecordPanic(methodName string) {}
|
||||
func (c *capturingMetricsProvider) Handler() http.Handler { return http.NewServeMux() }
|
||||
|
||||
func (c *capturingMetricsProvider) snapshot() []queryMetricCall {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
out := make([]queryMetricCall, len(c.calls))
|
||||
copy(out, c.calls)
|
||||
return out
|
||||
}
|
||||
|
||||
type queryMetricsGormUser struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
|
||||
func (queryMetricsGormUser) TableName() string {
|
||||
return "metrics_gorm_users"
|
||||
}
|
||||
|
||||
type queryMetricsBunUser struct {
|
||||
bun.BaseModel `bun:"table:metrics_bun_users"`
|
||||
ID int64 `bun:"id,pk,autoincrement"`
|
||||
Name string `bun:"name"`
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`).
|
||||
WithArgs("Alice", 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.NewUpdate().
|
||||
Table("public.users").
|
||||
Set("name", "Alice").
|
||||
Where("id = ?", 1).
|
||||
Exec(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "UPDATE", calls[0].operation)
|
||||
assert.Equal(t, "public", calls[0].schema)
|
||||
assert.Equal(t, "users", calls[0].entity)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`DELETE FROM users WHERE id = \$1`).
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db).SetMetricsEnabled(false)
|
||||
_, err = adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
Exec(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
assert.Empty(t, provider.snapshot())
|
||||
}
|
||||
|
||||
func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.AutoMigrate(&queryMetricsGormUser{}))
|
||||
require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error)
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
adapter := NewGormAdapter(db)
|
||||
var users []queryMetricsGormUser
|
||||
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, users)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "SELECT", calls[0].operation)
|
||||
assert.Equal(t, "default", calls[0].schema)
|
||||
assert.Equal(t, "query_metrics_gorm_user", calls[0].entity)
|
||||
assert.Equal(t, "metrics_gorm_users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsErrorMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`INSERT INTO users`).
|
||||
WillReturnError(fmt.Errorf("unique constraint violation"))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Alice").
|
||||
Exec(context.Background())
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "INSERT", calls[0].operation)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsExistsMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
exists, err := adapter.NewSelect().Table("users").Exists(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "EXISTS", calls[0].operation)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsCountMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
count, err := adapter.NewSelect().Table("users").Count(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, count)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "COUNT", calls[0].operation)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`UPDATE public\.orders SET status = \$1`).
|
||||
WithArgs("shipped").
|
||||
WillReturnResult(sqlmock.NewResult(0, 2))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.Exec(context.Background(), `UPDATE public.orders SET status = $1`, "shipped")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "UPDATE", calls[0].operation)
|
||||
assert.Equal(t, "public", calls[0].schema)
|
||||
assert.Equal(t, "orders", calls[0].table)
|
||||
}
|
||||
|
||||
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||
require.NoError(t, err)
|
||||
defer sqldb.Close()
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.NewCreateTable().
|
||||
Model((*queryMetricsBunUser)(nil)).
|
||||
IfNotExists().
|
||||
Exec(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.NewInsert().Model(&queryMetricsBunUser{Name: "Alice"}).Exec(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
var users []queryMetricsBunUser
|
||||
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, users)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "SELECT", calls[0].operation)
|
||||
assert.Equal(t, "default", calls[0].schema)
|
||||
assert.Equal(t, "query_metrics_bun_user", calls[0].entity)
|
||||
assert.Equal(t, "metrics_bun_users", calls[0].table)
|
||||
}
|
||||
@@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reconnectForAdapter() error {
|
||||
timeout := c.config.ConnectTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.Reconnect(ctx)
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Bun()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.GORM()
|
||||
}
|
||||
|
||||
// getBunAdapter returns or creates the Bun adapter
|
||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
if c == nil {
|
||||
@@ -391,7 +427,9 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
c.bunDB = bun.NewDB(native, dialect)
|
||||
}
|
||||
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB).
|
||||
WithDBFactory(c.reopenBunForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
@@ -432,7 +470,9 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
||||
c.gormDB = db
|
||||
}
|
||||
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB).
|
||||
WithDBFactory(c.reopenGORMForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
@@ -473,11 +513,17 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||
WithDBFactory(c.reopenNativeForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
case DatabaseTypeSQLite:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||
WithDBFactory(c.reopenNativeForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
case DatabaseTypeMSSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||
WithDBFactory(c.reopenNativeForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
|
||||
@@ -4,8 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
func TestNewConnectionFromDB(t *testing.T) {
|
||||
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-native",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeNative),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.PgSQLAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected PgSQLAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-bun",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeBun),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.BunAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BunAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying Bun database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter := adapter.GetUnderlyingDB()
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-gorm",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeGORM),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.GormAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected GormAdapter, got %T", db)
|
||||
}
|
||||
|
||||
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlBefore, err := gormBefore.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if err := sqlBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
|
||||
}
|
||||
if count < 0 {
|
||||
t.Fatalf("Expected non-negative count, got %d", count)
|
||||
}
|
||||
|
||||
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlAfter, err := gormAfter.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if sqlAfter == sqlBefore {
|
||||
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
|
||||
"connection", item.name,
|
||||
"error", err)
|
||||
|
||||
// Attempt reconnection if enabled
|
||||
if m.config.EnableAutoReconnect {
|
||||
// Only reconnect when the client handle itself is closed/disconnected.
|
||||
// For transient database restarts or network blips, *sql.DB can recover
|
||||
// on its own; forcing Close()+Connect() here invalidates any cached ORM
|
||||
// wrappers and callers that still hold the old handle.
|
||||
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
|
||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||
if err := item.conn.Reconnect(ctx); err != nil {
|
||||
logger.Error("Reconnection failed",
|
||||
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
|
||||
} else {
|
||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||
}
|
||||
} else if m.config.EnableAutoReconnect {
|
||||
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldReconnectAfterHealthCheck(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrConnectionClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
@@ -3,12 +3,38 @@ package dbmanager
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type healthCheckStubConnection struct {
|
||||
healthErr error
|
||||
reconnectCalls int
|
||||
}
|
||||
|
||||
func (c *healthCheckStubConnection) Name() string { return "stub" }
|
||||
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
|
||||
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
|
||||
func (c *healthCheckStubConnection) Close() error { return nil }
|
||||
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
|
||||
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
|
||||
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
|
||||
|
||||
func TestBackgroundHealthChecker(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
|
||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 0 {
|
||||
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 1 {
|
||||
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -19,7 +20,7 @@ type Provider interface {
|
||||
DecRequestsInFlight()
|
||||
|
||||
// RecordDBQuery records metrics for a database query
|
||||
RecordDBQuery(operation, table string, duration time.Duration, err error)
|
||||
RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error)
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
RecordCacheHit(provider string)
|
||||
@@ -46,21 +47,28 @@ type Provider interface {
|
||||
Handler() http.Handler
|
||||
}
|
||||
|
||||
// globalProvider is the global metrics provider
|
||||
var globalProvider Provider
|
||||
// globalProvider is the global metrics provider, protected by globalProviderMu.
|
||||
var (
|
||||
globalProviderMu sync.RWMutex
|
||||
globalProvider Provider
|
||||
)
|
||||
|
||||
// SetProvider sets the global metrics provider
|
||||
// SetProvider sets the global metrics provider.
|
||||
func SetProvider(p Provider) {
|
||||
globalProviderMu.Lock()
|
||||
globalProvider = p
|
||||
globalProviderMu.Unlock()
|
||||
}
|
||||
|
||||
// GetProvider returns the current metrics provider
|
||||
// GetProvider returns the current metrics provider.
|
||||
func GetProvider() Provider {
|
||||
if globalProvider == nil {
|
||||
// Return no-op provider if none is set
|
||||
globalProviderMu.RLock()
|
||||
p := globalProvider
|
||||
globalProviderMu.RUnlock()
|
||||
if p == nil {
|
||||
return &NoOpProvider{}
|
||||
}
|
||||
return globalProvider
|
||||
return p
|
||||
}
|
||||
|
||||
// NoOpProvider is a no-op implementation of Provider
|
||||
@@ -69,7 +77,7 @@ type NoOpProvider struct{}
|
||||
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
|
||||
func (n *NoOpProvider) IncRequestsInFlight() {}
|
||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||
}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
|
||||
@@ -83,14 +83,14 @@ func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
|
||||
Help: "Database query duration in seconds",
|
||||
Buckets: cfg.DBQueryBuckets,
|
||||
},
|
||||
[]string{"operation", "table"},
|
||||
[]string{"operation", "schema", "entity", "table"},
|
||||
),
|
||||
dbQueryTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricName("db_queries_total"),
|
||||
Help: "Total number of database queries",
|
||||
},
|
||||
[]string{"operation", "table", "status"},
|
||||
[]string{"operation", "schema", "entity", "table", "status"},
|
||||
),
|
||||
cacheHits: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
@@ -204,13 +204,13 @@ func (p *PrometheusProvider) DecRequestsInFlight() {
|
||||
}
|
||||
|
||||
// RecordDBQuery implements Provider interface
|
||||
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
func (p *PrometheusProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||
status := "success"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
|
||||
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
|
||||
p.dbQueryDuration.WithLabelValues(operation, schema, entity, table).Observe(duration.Seconds())
|
||||
p.dbQueryTotal.WithLabelValues(operation, schema, entity, table, status).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheHit implements Provider interface
|
||||
|
||||
@@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
|
||||
db := a.getDB()
|
||||
if db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
err := run(db)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = run(a.getDB())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Convert LoginRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
runLoginQuery := func() error {
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
@@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
@@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
@@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
}
|
||||
@@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package security
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
|
||||
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
|
||||
t.Helper()
|
||||
|
||||
primaryDB, primaryMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create primary mock db: %v", err)
|
||||
}
|
||||
|
||||
reconnectDB, reconnectMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
primaryDB.Close()
|
||||
t.Fatalf("failed to create reconnect mock db: %v", err)
|
||||
}
|
||||
|
||||
cacheProvider := cache.NewMemoryProvider(&cache.Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
})
|
||||
|
||||
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
|
||||
Cache: cache.NewCache(cacheProvider),
|
||||
DBFactory: func() (*sql.DB, error) {
|
||||
return reconnectDB, nil
|
||||
},
|
||||
})
|
||||
|
||||
cleanup := func() {
|
||||
_ = primaryDB.Close()
|
||||
_ = reconnectDB.Close()
|
||||
}
|
||||
|
||||
return auth, primaryMock, reconnectMock, cleanup
|
||||
}
|
||||
|
||||
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
userCtx, err := auth.Authenticate(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected authenticate to reconnect, got %v", err)
|
||||
}
|
||||
if userCtx.UserID != 7 {
|
||||
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Register reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := RegisterRequest{
|
||||
Username: "reconnect-register",
|
||||
Password: "password123",
|
||||
Email: "reconnect@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
resp, err := auth.Register(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected register to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "reconnected-register-token" {
|
||||
t.Fatalf("expected refreshed token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, nil)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
if err := auth.Logout(context.Background(), req); err != nil {
|
||||
t.Fatalf("expected logout to reconnect, got %v", err)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
refreshToken := "refresh-reconnect-token"
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnRows(sessionRows)
|
||||
|
||||
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
|
||||
WithArgs(refreshToken, sqlmock.AnyArg()).
|
||||
WillReturnRows(refreshRows)
|
||||
|
||||
resp, err := auth.RefreshToken(context.Background(), refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("expected refresh token to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "refreshed-token" {
|
||||
t.Fatalf("expected refreshed-token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test JWTAuthenticator
|
||||
func TestJWTAuthenticator(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
|
||||
Reference in New Issue
Block a user