feat(db): add query metrics tracking for database operations

* Introduced metrics tracking for SELECT, INSERT, UPDATE, and DELETE operations.
* Added methods to enable or disable metrics on the PgSQLAdapter.
* Created a new query_metrics.go file to handle metrics recording logic.
* Updated interfaces and implementations to support schema and entity tracking.
* Added tests to verify metrics recording functionality.
This commit is contained in:
Hein
2026-04-10 13:51:46 +02:00
parent 4fc25c60ae
commit e8d0ab28c3
8 changed files with 864 additions and 201 deletions

View File

@@ -95,15 +95,16 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
// BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs
type BunAdapter struct {
db *bun.DB
dbMu sync.RWMutex
dbFactory func() (*bun.DB, error)
driverName string
db *bun.DB
dbMu sync.RWMutex
dbFactory func() (*bun.DB, error)
driverName string
metricsEnabled bool
}
// NewBunAdapter creates a new Bun adapter
func NewBunAdapter(db *bun.DB) *BunAdapter {
adapter := &BunAdapter{db: db}
adapter := &BunAdapter{db: db, metricsEnabled: true}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
@@ -115,6 +116,22 @@ func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter
return b
}
// SetMetricsEnabled enables or disables query metrics for this adapter.
func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter {
b.metricsEnabled = enabled
return b
}
// EnableMetrics enables query metrics for this adapter.
func (b *BunAdapter) EnableMetrics() *BunAdapter {
return b.SetMetricsEnabled(true)
}
// DisableMetrics disables query metrics for this adapter.
func (b *BunAdapter) DisableMetrics() *BunAdapter {
return b.SetMetricsEnabled(false)
}
func (b *BunAdapter) getDB() *bun.DB {
b.dbMu.RLock()
defer b.dbMu.RUnlock()
@@ -159,22 +176,23 @@ func (b *BunAdapter) DisableQueryDebug() {
func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{
query: b.getDB().NewSelect(),
db: b.db,
driverName: b.driverName,
query: b.getDB().NewSelect(),
db: b.db,
driverName: b.driverName,
metricsEnabled: b.metricsEnabled,
}
}
func (b *BunAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.getDB().NewInsert()}
return &BunInsertQuery{query: b.getDB().NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
return &BunUpdateQuery{query: b.getDB().NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.getDB().NewDelete()}
return &BunDeleteQuery{query: b.getDB().NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -183,6 +201,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
var result sql.Result
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
err = run()
@@ -191,6 +211,7 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
err = run()
}
}
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return &BunResult{result: result}, err
}
@@ -200,12 +221,15 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
}
}
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
@@ -219,7 +243,7 @@ func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
if err != nil {
return nil, err
}
return &BunTxAdapter{tx: tx, driverName: b.driverName}, nil
return &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}, nil
}
func (b *BunAdapter) CommitTx(ctx context.Context) error {
@@ -242,7 +266,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
}()
run := func() error {
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
return fn(adapter)
})
}
@@ -280,25 +304,24 @@ type BunSelectQuery struct {
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
entity string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
skipAutoDetect bool // Skip auto-detection to prevent circular calls
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
metricsEnabled bool
}
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
b.query = b.query.Model(model)
b.hasModel = true // Mark that we have a model
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
b.schema, b.tableName = parseTableName(fullTableName, b.driverName)
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
if provider, ok := model.(common.TableAliasProvider); ok {
b.tableAlias = provider.TableAlias()
@@ -312,6 +335,9 @@ func (b *BunSelectQuery) Table(table string) common.SelectQuery {
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
return b
}
@@ -617,9 +643,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Wrap the incoming *bun.SelectQuery in our adapter
wrapper := &BunSelectQuery{
query: sq,
db: b.db,
driverName: b.driverName,
query: sq,
db: b.db,
driverName: b.driverName,
metricsEnabled: b.metricsEnabled,
}
// Try to extract table name and alias from the preload model
@@ -870,7 +897,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
// Apply user's functions (if any)
if isLast && len(applyFuncs) > 0 {
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName}
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
for _, fn := range applyFuncs {
if fn != nil {
wrapper = fn(wrapper).(*BunSelectQuery)
@@ -1227,6 +1254,7 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
}()
startedAt := time.Now()
if dest == nil {
return fmt.Errorf("destination cannot be nil")
}
@@ -1236,9 +1264,11 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
return err
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, nil)
return nil
}
@@ -1276,6 +1306,7 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
if b.query.GetModel() == nil {
return fmt.Errorf("model is nil")
}
startedAt := time.Now()
// Optional: Enable detailed field-level debugging (set to true to debug)
const enableDetailedDebug = true
@@ -1293,6 +1324,7 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
return err
}
@@ -1301,10 +1333,12 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
if err := b.loadCustomPreloads(ctx); err != nil {
logger.Error("Failed to load custom preloads: %v", err)
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
return err
}
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, nil)
return nil
}
@@ -1315,6 +1349,7 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
count = 0
}
}()
startedAt := time.Now()
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
@@ -1323,6 +1358,7 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
return count, err
}
@@ -1337,6 +1373,7 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
return count, err
}
@@ -1347,25 +1384,37 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
exists = false
}
}()
startedAt := time.Now()
exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
return exists, err
}
// BunInsertQuery implements InsertQuery for Bun
type BunInsertQuery struct {
query *bun.InsertQuery
values map[string]interface{}
hasModel bool
query *bun.InsertQuery
values map[string]interface{}
hasModel bool
driverName string
schema string
tableName string
entity string
metricsEnabled bool
}
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
b.query = b.query.Model(model)
b.hasModel = true
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
return b
}
@@ -1374,6 +1423,10 @@ func (b *BunInsertQuery) Table(table string) common.InsertQuery {
return b
}
b.query = b.query.Table(table)
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
return b
}
@@ -1403,6 +1456,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
startedAt := time.Now()
if len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
@@ -1416,29 +1470,45 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
}
}
result, err := b.query.Exec(ctx)
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
return &BunResult{result: result}, err
}
// BunUpdateQuery implements UpdateQuery for Bun
type BunUpdateQuery struct {
query *bun.UpdateQuery
model interface{}
query *bun.UpdateQuery
model interface{}
driverName string
schema string
tableName string
entity string
metricsEnabled bool
}
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
b.query = b.query.Model(model)
b.model = model
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
return b
}
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
b.query = b.query.Table(table)
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
if b.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
b.model = model
b.entity = entityNameFromModel(model, b.tableName)
}
}
return b
@@ -1489,27 +1559,43 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
}
}()
startedAt := time.Now()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "UPDATE", b.schema, b.entity, b.tableName, startedAt, err)
return &BunResult{result: result}, err
}
// BunDeleteQuery implements DeleteQuery for Bun
type BunDeleteQuery struct {
query *bun.DeleteQuery
query *bun.DeleteQuery
driverName string
schema string
tableName string
entity string
metricsEnabled bool
}
func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery {
b.query = b.query.Model(model)
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
if b.tableName == "" {
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
}
b.entity = entityNameFromModel(model, b.tableName)
return b
}
func (b *BunDeleteQuery) Table(table string) common.DeleteQuery {
b.query = b.query.Table(table)
b.schema, b.tableName = parseTableName(table, b.driverName)
if b.entity == "" {
b.entity = cleanMetricIdentifier(b.tableName)
}
return b
}
@@ -1524,12 +1610,14 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
}
}()
startedAt := time.Now()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "DELETE", b.schema, b.entity, b.tableName, startedAt, err)
return &BunResult{result: result}, err
}
@@ -1555,37 +1643,46 @@ func (b *BunResult) LastInsertId() (int64, error) {
// BunTxAdapter wraps a Bun transaction to implement the Database interface
type BunTxAdapter struct {
tx bun.Tx
driverName string
tx bun.Tx
driverName string
metricsEnabled bool
}
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{
query: b.tx.NewSelect(),
db: b.tx,
driverName: b.driverName,
query: b.tx.NewSelect(),
db: b.tx,
driverName: b.driverName,
metricsEnabled: b.metricsEnabled,
}
}
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.tx.NewInsert()}
return &BunInsertQuery{query: b.tx.NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunTxAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.tx.NewUpdate()}
return &BunUpdateQuery{query: b.tx.NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunTxAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.tx.NewDelete()}
return &BunDeleteQuery{query: b.tx.NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
}
func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
result, err := b.tx.ExecContext(ctx, query, args...)
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return &BunResult{result: result}, err
}
func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return b.tx.NewRaw(query, args...).Scan(ctx, dest)
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
err := b.tx.NewRaw(query, args...).Scan(ctx, dest)
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"sync"
"time"
"gorm.io/gorm"
@@ -16,15 +17,16 @@ import (
// GormAdapter adapts GORM to work with our Database interface
type GormAdapter struct {
dbMu sync.RWMutex
db *gorm.DB
dbFactory func() (*gorm.DB, error)
driverName string
dbMu sync.RWMutex
db *gorm.DB
dbFactory func() (*gorm.DB, error)
driverName string
metricsEnabled bool
}
// NewGormAdapter creates a new GORM adapter
func NewGormAdapter(db *gorm.DB) *GormAdapter {
adapter := &GormAdapter{db: db}
adapter := &GormAdapter{db: db, metricsEnabled: true}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
@@ -36,6 +38,22 @@ func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapt
return g
}
// SetMetricsEnabled enables or disables query metrics for this adapter.
func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter {
g.metricsEnabled = enabled
return g
}
// EnableMetrics enables query metrics for this adapter.
func (g *GormAdapter) EnableMetrics() *GormAdapter {
return g.SetMetricsEnabled(true)
}
// DisableMetrics disables query metrics for this adapter.
func (g *GormAdapter) DisableMetrics() *GormAdapter {
return g.SetMetricsEnabled(false)
}
func (g *GormAdapter) getDB() *gorm.DB {
g.dbMu.RLock()
defer g.dbMu.RUnlock()
@@ -109,19 +127,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
}
func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB}
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewInsert() common.InsertQuery {
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB}
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB}
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB}
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -130,6 +148,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
run := func() *gorm.DB {
return g.getDB().WithContext(ctx).Exec(query, args...)
}
@@ -139,6 +159,7 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
result = run()
}
}
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}
@@ -148,6 +169,8 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
run := func() error {
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
}
@@ -157,6 +180,7 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
err = run()
}
}
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
@@ -173,7 +197,7 @@ func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
if tx.Error != nil {
return nil, tx.Error
}
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}, nil
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}, nil
}
func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -192,7 +216,7 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
}()
run := func() error {
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName}
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
return fn(adapter)
})
}
@@ -236,22 +260,18 @@ type GormSelectQuery struct {
reconnect func(...*gorm.DB) error
schema string // Separated schema name
tableName string // Just the table name, without schema
entity string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
metricsEnabled bool
}
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
g.db = g.db.Model(model)
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
}
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
if provider, ok := model.(common.TableAliasProvider); ok {
g.tableAlias = provider.TableAlias()
@@ -265,6 +285,9 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
// Check if the table name contains schema (e.g., "schema.table")
// For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
return g
}
@@ -450,9 +473,10 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
}
wrapper := &GormSelectQuery{
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
db: db,
reconnect: g.reconnect,
driverName: g.driverName,
metricsEnabled: g.metricsEnabled,
}
current := common.SelectQuery(wrapper)
@@ -494,6 +518,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
driverName: g.driverName,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
metricsEnabled: g.metricsEnabled,
}
current := common.SelectQuery(wrapper)
@@ -550,6 +575,7 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
startedAt := time.Now()
run := func() error {
return g.db.WithContext(ctx).Find(dest).Error
}
@@ -566,6 +592,7 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
return err
}
@@ -578,6 +605,7 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
startedAt := time.Now()
run := func() error {
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
}
@@ -594,6 +622,7 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
return err
}
@@ -604,6 +633,7 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
count = 0
}
}()
startedAt := time.Now()
var count64 int64
run := func() error {
return g.db.WithContext(ctx).Count(&count64).Error
@@ -621,6 +651,7 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "COUNT", g.schema, g.entity, g.tableName, startedAt, err)
return int(count64), err
}
@@ -631,6 +662,7 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
exists = false
}
}()
startedAt := time.Now()
var count int64
run := func() error {
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
@@ -648,25 +680,36 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(g.metricsEnabled, "EXISTS", g.schema, g.entity, g.tableName, startedAt, err)
return count > 0, err
}
// GormInsertQuery implements InsertQuery for GORM
type GormInsertQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
values map[string]interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
values map[string]interface{}
schema string
tableName string
entity string
metricsEnabled bool
}
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name())
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.db.Name())
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
return g
}
@@ -694,6 +737,7 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
startedAt := time.Now()
run := func() *gorm.DB {
switch {
case g.model != nil:
@@ -710,30 +754,42 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
result = run()
}
}
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}
// GormUpdateQuery implements UpdateQuery for GORM
type GormUpdateQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
updates interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
updates interface{}
schema string
tableName string
entity string
metricsEnabled bool
}
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name())
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.db.Name())
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
if g.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
g.model = model
g.entity = entityNameFromModel(model, g.tableName)
}
}
return g
@@ -794,6 +850,7 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
startedAt := time.Now()
run := func() *gorm.DB {
return g.db.WithContext(ctx).Updates(g.updates)
}
@@ -810,24 +867,35 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
recordQueryMetrics(g.metricsEnabled, "UPDATE", g.schema, g.entity, g.tableName, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}
// GormDeleteQuery implements DeleteQuery for GORM
type GormDeleteQuery struct {
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
db *gorm.DB
reconnect func(...*gorm.DB) error
model interface{}
schema string
tableName string
entity string
metricsEnabled bool
}
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name())
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.db.Name())
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
return g
}
@@ -842,6 +910,7 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
startedAt := time.Now()
run := func() *gorm.DB {
return g.db.WithContext(ctx).Delete(g.model)
}
@@ -858,6 +927,7 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
recordQueryMetrics(g.metricsEnabled, "DELETE", g.schema, g.entity, g.tableName, startedAt, result.Error)
return &GormResult{result: result}, result.Error
}

View File

@@ -5,8 +5,10 @@ import (
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -17,10 +19,11 @@ import (
// PgSQLAdapter adapts standard database/sql to work with our Database interface
// This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct {
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string
metricsEnabled bool
}
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
@@ -31,7 +34,7 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
if len(driverName) > 0 && driverName[0] != "" {
name = driverName[0]
}
return &PgSQLAdapter{db: db, driverName: name}
return &PgSQLAdapter{db: db, driverName: name, metricsEnabled: true}
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
@@ -40,6 +43,22 @@ func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdap
return p
}
// SetMetricsEnabled enables or disables query metrics for this adapter.
func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter {
p.metricsEnabled = enabled
return p
}
// EnableMetrics enables query metrics for this adapter.
func (p *PgSQLAdapter) EnableMetrics() *PgSQLAdapter {
return p.SetMetricsEnabled(true)
}
// DisableMetrics disables query metrics for this adapter.
func (p *PgSQLAdapter) DisableMetrics() *PgSQLAdapter {
return p.SetMetricsEnabled(false)
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
@@ -71,37 +90,41 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
db: p.getDB(),
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
db: p.getDB(),
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
db: p.getDB(),
driverName: p.driverName,
values: make(map[string]interface{}),
db: p.getDB(),
driverName: p.driverName,
values: make(map[string]interface{}),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
db: p.getDB(),
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
db: p.getDB(),
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
db: p.getDB(),
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
db: p.getDB(),
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
@@ -111,6 +134,8 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
var result sql.Result
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
@@ -122,8 +147,10 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
}
if err != nil {
logger.Error("PgSQL Exec failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
@@ -133,6 +160,8 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
err = logger.HandlePanic("PgSQLAdapter.Query", r)
}
}()
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
var rows *sql.Rows
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
@@ -144,11 +173,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
}
if err != nil {
logger.Error("PgSQL Query failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
defer rows.Close()
return scanRows(rows, dest)
err = scanRows(rows, dest)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
@@ -156,7 +188,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
if err != nil {
return nil, err
}
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}, nil
}
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
@@ -179,7 +211,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
return err
}
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}
defer func() {
if p := recover(); p != nil {
@@ -222,34 +254,34 @@ type relationMetadata struct {
// PgSQLSelectQuery implements SelectQuery for PostgreSQL
type PgSQLSelectQuery struct {
db *sql.DB
tx *sql.Tx
model interface{}
tableName string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
columns []string
columnExprs []string
whereClauses []string
orClauses []string
joins []string
orderBy []string
groupBy []string
havingClauses []string
limit int
offset int
args []interface{}
paramCounter int
preloads []preloadConfig
db *sql.DB
tx *sql.Tx
model interface{}
entity string
tableName string
schema string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
columns []string
columnExprs []string
whereClauses []string
orClauses []string
joins []string
orderBy []string
groupBy []string
havingClauses []string
limit int
offset int
args []interface{}
paramCounter int
preloads []preloadConfig
metricsEnabled bool
}
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
p.model = model
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
if provider, ok := model.(common.TableAliasProvider); ok {
p.tableAlias = provider.TableAlias()
}
@@ -258,7 +290,10 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
return p
}
@@ -468,6 +503,7 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
}
}()
startedAt := time.Now()
// Apply preloads that use JOINs
p.applyJoinPreloads()
@@ -484,17 +520,21 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
if err != nil {
logger.Error("PgSQL SELECT failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
return err
}
defer rows.Close()
err = scanRows(rows, dest)
if err != nil {
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
return err
}
// Apply preloads that use separate queries
return p.applySubqueryPreloads(ctx, dest)
err = p.applySubqueryPreloads(ctx, dest)
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
return err
}
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
@@ -511,6 +551,7 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
count = 0
}
}()
startedAt := time.Now()
// Build a COUNT query
var sb strings.Builder
@@ -550,6 +591,7 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
if err != nil {
logger.Error("PgSQL COUNT failed: %v", err)
}
recordQueryMetrics(p.metricsEnabled, "COUNT", p.schema, p.entity, p.tableName, startedAt, err)
return count, err
}
@@ -567,20 +609,21 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
type PgSQLInsertQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
values map[string]interface{}
returning []string
db *sql.DB
tx *sql.Tx
schema string
tableName string
entity string
driverName string
values map[string]interface{}
valueOrder []string
returning []string
metricsEnabled bool
}
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
// Extract values from model using reflection
// This is a simplified implementation
return p
@@ -588,11 +631,17 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
return p
}
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
if _, exists := p.values[column]; !exists {
p.valueOrder = append(p.valueOrder, column)
}
p.values[column] = value
return p
}
@@ -613,6 +662,7 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
}
}()
startedAt := time.Now()
if len(p.values) == 0 {
return nil, fmt.Errorf("no values to insert")
@@ -621,12 +671,11 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
columns := make([]string, 0, len(p.values))
placeholders := make([]string, 0, len(p.values))
args := make([]interface{}, 0, len(p.values))
i := 1
for col, val := range p.values {
for _, col := range p.valueOrder {
columns = append(columns, col)
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
args = append(args, val)
args = append(args, p.values[col])
i++
}
@@ -650,43 +699,50 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
if err != nil {
logger.Error("PgSQL INSERT failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
type PgSQLUpdateQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
model interface{}
sets map[string]interface{}
whereClauses []string
args []interface{}
paramCounter int
returning []string
db *sql.DB
tx *sql.Tx
schema string
tableName string
entity string
driverName string
model interface{}
sets map[string]interface{}
setOrder []string
whereClauses []string
args []interface{}
paramCounter int
returning []string
metricsEnabled bool
}
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
p.model = model
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
return p
}
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
if p.model == nil {
model, err := modelregistry.GetModelByName(table)
if err == nil {
p.model = model
p.entity = entityNameFromModel(model, p.tableName)
}
}
return p
@@ -696,6 +752,9 @@ func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQu
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
return p
}
if _, exists := p.sets[column]; !exists {
p.setOrder = append(p.setOrder, column)
}
p.sets[column] = value
return p
}
@@ -706,13 +765,23 @@ func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQu
pkName = reflection.GetPrimaryKeyName(p.model)
}
for column, value := range values {
orderedColumns := make([]string, 0, len(values))
for column := range values {
orderedColumns = append(orderedColumns, column)
}
sort.Strings(orderedColumns)
for _, column := range orderedColumns {
value := values[column]
if pkName != "" && column == pkName {
continue
}
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
continue
}
if _, exists := p.sets[column]; !exists {
p.setOrder = append(p.setOrder, column)
}
p.sets[column] = value
}
return p
@@ -746,6 +815,7 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
}
}()
startedAt := time.Now()
if len(p.sets) == 0 {
return nil, fmt.Errorf("no values to update")
@@ -753,12 +823,11 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
setClauses := make([]string, 0, len(p.sets))
setArgs := make([]interface{}, 0, len(p.sets))
// SET parameters start at $1
i := 1
for col, val := range p.sets {
for _, col := range p.setOrder {
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
setArgs = append(setArgs, val)
setArgs = append(setArgs, p.sets[col])
i++
}
@@ -812,35 +881,40 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
if err != nil {
logger.Error("PgSQL UPDATE failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
// PgSQLDeleteQuery implements DeleteQuery for PostgreSQL
type PgSQLDeleteQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
whereClauses []string
args []interface{}
paramCounter int
db *sql.DB
tx *sql.Tx
schema string
tableName string
entity string
driverName string
whereClauses []string
args []interface{}
paramCounter int
metricsEnabled bool
}
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
p.entity = entityNameFromModel(model, p.tableName)
return p
}
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
p.schema, p.tableName = parseTableName(table, p.driverName)
if p.entity == "" {
p.entity = cleanMetricIdentifier(p.tableName)
}
return p
}
@@ -867,6 +941,7 @@ func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err err
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
}
}()
startedAt := time.Now()
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
@@ -885,9 +960,11 @@ func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err err
if err != nil {
logger.Error("PgSQL DELETE failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
@@ -913,66 +990,80 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
// PgSQLTxAdapter wraps a PostgreSQL transaction
type PgSQLTxAdapter struct {
tx *sql.Tx
driverName string
tx *sql.Tx
driverName string
metricsEnabled bool
}
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
tx: p.tx,
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
tx: p.tx,
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
tx: p.tx,
driverName: p.driverName,
values: make(map[string]interface{}),
tx: p.tx,
driverName: p.driverName,
values: make(map[string]interface{}),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
tx: p.tx,
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
tx: p.tx,
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
tx: p.tx,
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
tx: p.tx,
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
metricsEnabled: p.metricsEnabled,
}
}
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
result, err := p.tx.ExecContext(ctx, query, args...)
if err != nil {
logger.Error("PgSQL Tx Exec failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
startedAt := time.Now()
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
rows, err := p.tx.QueryContext(ctx, query, args...)
if err != nil {
logger.Error("PgSQL Tx Query failed: %v", err)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
defer rows.Close()
return scanRows(rows, dest)
err = scanRows(rows, dest)
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
return err
}
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {

View File

@@ -0,0 +1,197 @@
package database
import (
"reflect"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) {
if !enabled {
return
}
metrics.GetProvider().RecordDBQuery(
normalizeMetricOperation(operation),
normalizeMetricSchema(schema),
normalizeMetricEntity(entity, table),
normalizeMetricTable(table),
time.Since(startedAt),
err,
)
}
func normalizeMetricOperation(operation string) string {
operation = strings.ToUpper(strings.TrimSpace(operation))
if operation == "" {
return "UNKNOWN"
}
return operation
}
func normalizeMetricSchema(schema string) string {
schema = cleanMetricIdentifier(schema)
if schema == "" {
return "default"
}
return schema
}
func normalizeMetricEntity(entity, table string) string {
entity = cleanMetricIdentifier(entity)
if entity != "" {
return entity
}
table = cleanMetricIdentifier(table)
if table != "" {
return table
}
return "unknown"
}
func normalizeMetricTable(table string) string {
table = cleanMetricIdentifier(table)
if table == "" {
return "unknown"
}
return table
}
func entityNameFromModel(model interface{}, table string) string {
if model == nil {
return cleanMetricIdentifier(table)
}
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil {
return cleanMetricIdentifier(table)
}
if modelType.Kind() == reflect.Struct && modelType.Name() != "" {
return reflection.ToSnakeCase(modelType.Name())
}
return cleanMetricIdentifier(table)
}
func schemaAndTableFromModel(model interface{}, driverName string) (schema, table string) {
provider, ok := tableNameProviderFromModel(model)
if !ok {
return "", ""
}
return parseTableName(provider.TableName(), driverName)
}
func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) {
if model == nil {
return nil, false
}
if provider, ok := model.(common.TableNameProvider); ok {
return provider, true
}
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil, false
}
modelValue := reflect.New(modelType)
if provider, ok := modelValue.Interface().(common.TableNameProvider); ok {
return provider, true
}
if provider, ok := modelValue.Elem().Interface().(common.TableNameProvider); ok {
return provider, true
}
return nil, false
}
func metricTargetFromRawQuery(query, driverName string) (operation, schema, entity, table string) {
operation = normalizeMetricOperation(firstQueryKeyword(query))
tableRef := tableFromRawQuery(query, operation)
if tableRef == "" {
return operation, "", "unknown", "unknown"
}
schema, table = parseTableName(tableRef, driverName)
entity = cleanMetricIdentifier(table)
return operation, schema, entity, table
}
func firstQueryKeyword(query string) string {
query = strings.TrimSpace(query)
if query == "" {
return ""
}
fields := strings.Fields(query)
if len(fields) == 0 {
return ""
}
return fields[0]
}
func tableFromRawQuery(query, operation string) string {
tokens := tokenizeQuery(query)
if len(tokens) == 0 {
return ""
}
switch operation {
case "SELECT":
return tokenAfter(tokens, "FROM")
case "INSERT":
return tokenAfter(tokens, "INTO")
case "UPDATE":
return tokenAfter(tokens, "UPDATE")
case "DELETE":
return tokenAfter(tokens, "FROM")
default:
return ""
}
}
func tokenAfter(tokens []string, keyword string) string {
for idx, token := range tokens {
if strings.EqualFold(token, keyword) && idx+1 < len(tokens) {
return cleanMetricIdentifier(tokens[idx+1])
}
}
return ""
}
func tokenizeQuery(query string) []string {
replacer := strings.NewReplacer(
"\n", " ",
"\t", " ",
"(", " ",
")", " ",
",", " ",
)
return strings.Fields(replacer.Replace(query))
}
func cleanMetricIdentifier(value string) string {
value = strings.TrimSpace(value)
value = strings.Trim(value, "\"'`[]")
value = strings.TrimRight(value, ";")
return value
}

View File

@@ -0,0 +1,198 @@
package database
import (
"context"
"database/sql"
"net/http"
"sync"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/driver/sqliteshim"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
type queryMetricCall struct {
operation string
schema string
entity string
table string
}
type capturingMetricsProvider struct {
mu sync.Mutex
calls []queryMetricCall
}
func (c *capturingMetricsProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
}
func (c *capturingMetricsProvider) IncRequestsInFlight() {}
func (c *capturingMetricsProvider) DecRequestsInFlight() {}
func (c *capturingMetricsProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.calls = append(c.calls, queryMetricCall{
operation: operation,
schema: schema,
entity: entity,
table: table,
})
}
func (c *capturingMetricsProvider) RecordCacheHit(provider string) {}
func (c *capturingMetricsProvider) RecordCacheMiss(provider string) {}
func (c *capturingMetricsProvider) UpdateCacheSize(provider string, size int64) {
}
func (c *capturingMetricsProvider) RecordEventPublished(source, eventType string) {}
func (c *capturingMetricsProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
}
func (c *capturingMetricsProvider) UpdateEventQueueSize(size int64) {}
func (c *capturingMetricsProvider) RecordPanic(methodName string) {}
func (c *capturingMetricsProvider) Handler() http.Handler { return http.NewServeMux() }
func (c *capturingMetricsProvider) snapshot() []queryMetricCall {
c.mu.Lock()
defer c.mu.Unlock()
out := make([]queryMetricCall, len(c.calls))
copy(out, c.calls)
return out
}
type queryMetricsGormUser struct {
ID int `gorm:"primaryKey"`
Name string
}
func (queryMetricsGormUser) TableName() string {
return "metrics_gorm_users"
}
type queryMetricsBunUser struct {
bun.BaseModel `bun:"table:metrics_bun_users"`
ID int64 `bun:"id,pk,autoincrement"`
Name string `bun:"name"`
}
func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`).
WithArgs("Alice", 1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
_, err = adapter.NewUpdate().
Table("public.users").
Set("name", "Alice").
Where("id = ?", 1).
Exec(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "UPDATE", calls[0].operation)
assert.Equal(t, "public", calls[0].schema)
assert.Equal(t, "users", calls[0].entity)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
mock.ExpectExec(`DELETE FROM users WHERE id = \$1`).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db).DisableMetrics()
_, err = adapter.NewDelete().
Table("users").
Where("id = ?", 1).
Exec(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
assert.Empty(t, provider.snapshot())
}
func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&queryMetricsGormUser{}))
require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error)
provider := &capturingMetricsProvider{}
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
adapter := NewGormAdapter(db)
var users []queryMetricsGormUser
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
require.NoError(t, err)
require.NotEmpty(t, users)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "SELECT", calls[0].operation)
assert.Equal(t, "default", calls[0].schema)
assert.Equal(t, "query_metrics_gorm_user", calls[0].entity)
assert.Equal(t, "metrics_gorm_users", calls[0].table)
}
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
require.NoError(t, err)
defer sqldb.Close()
db := bun.NewDB(sqldb, sqlitedialect.New())
defer db.Close()
_, err = db.NewCreateTable().
Model((*queryMetricsBunUser)(nil)).
IfNotExists().
Exec(context.Background())
require.NoError(t, err)
_, err = db.NewInsert().Model(&queryMetricsBunUser{Name: "Alice"}).Exec(context.Background())
require.NoError(t, err)
provider := &capturingMetricsProvider{}
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
adapter := NewBunAdapter(db)
var users []queryMetricsBunUser
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
require.NoError(t, err)
require.NotEmpty(t, users)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "SELECT", calls[0].operation)
assert.Equal(t, "default", calls[0].schema)
assert.Equal(t, "query_metrics_bun_user", calls[0].entity)
assert.Equal(t, "metrics_bun_users", calls[0].table)
}

View File

@@ -427,7 +427,9 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
c.bunDB = bun.NewDB(native, dialect)
}
c.bunAdapter = database.NewBunAdapter(c.bunDB).WithDBFactory(c.reopenBunForAdapter)
c.bunAdapter = database.NewBunAdapter(c.bunDB).
WithDBFactory(c.reopenBunForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
return c.bunAdapter, nil
}
@@ -468,7 +470,9 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
c.gormDB = db
}
c.gormAdapter = database.NewGormAdapter(c.gormDB).WithDBFactory(c.reopenGORMForAdapter)
c.gormAdapter = database.NewGormAdapter(c.gormDB).
WithDBFactory(c.reopenGORMForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
return c.gormAdapter, nil
}
@@ -509,11 +513,17 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
// Create a native adapter based on database type
switch c.dbType {
case DatabaseTypePostgreSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
WithDBFactory(c.reopenNativeForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
case DatabaseTypeSQLite:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
WithDBFactory(c.reopenNativeForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
case DatabaseTypeMSSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).WithDBFactory(c.reopenNativeForAdapter)
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
WithDBFactory(c.reopenNativeForAdapter).
SetMetricsEnabled(c.config.EnableMetrics)
default:
return nil, ErrUnsupportedDatabase
}

View File

@@ -19,7 +19,7 @@ type Provider interface {
DecRequestsInFlight()
// RecordDBQuery records metrics for a database query
RecordDBQuery(operation, table string, duration time.Duration, err error)
RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error)
// RecordCacheHit records a cache hit
RecordCacheHit(provider string)
@@ -69,7 +69,7 @@ type NoOpProvider struct{}
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
func (n *NoOpProvider) IncRequestsInFlight() {}
func (n *NoOpProvider) DecRequestsInFlight() {}
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
func (n *NoOpProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
}
func (n *NoOpProvider) RecordCacheHit(provider string) {}
func (n *NoOpProvider) RecordCacheMiss(provider string) {}

View File

@@ -83,14 +83,14 @@ func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
Help: "Database query duration in seconds",
Buckets: cfg.DBQueryBuckets,
},
[]string{"operation", "table"},
[]string{"operation", "schema", "entity", "table"},
),
dbQueryTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: metricName("db_queries_total"),
Help: "Total number of database queries",
},
[]string{"operation", "table", "status"},
[]string{"operation", "schema", "entity", "table", "status"},
),
cacheHits: promauto.NewCounterVec(
prometheus.CounterOpts{
@@ -204,13 +204,13 @@ func (p *PrometheusProvider) DecRequestsInFlight() {
}
// RecordDBQuery implements Provider interface
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
func (p *PrometheusProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
status := "success"
if err != nil {
status = "error"
}
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
p.dbQueryDuration.WithLabelValues(operation, schema, entity, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, schema, entity, table, status).Inc()
}
// RecordCacheHit implements Provider interface