refactor(db): remove metrics enabling methods from adapters

This commit is contained in:
Hein
2026-04-10 14:13:15 +02:00
parent e8d0ab28c3
commit dfb63c3328
6 changed files with 201 additions and 103 deletions

View File

@@ -122,16 +122,6 @@ func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter {
return b
}
// EnableMetrics enables query metrics for this adapter.
func (b *BunAdapter) EnableMetrics() *BunAdapter {
return b.SetMetricsEnabled(true)
}
// DisableMetrics disables query metrics for this adapter.
func (b *BunAdapter) DisableMetrics() *BunAdapter {
return b.SetMetricsEnabled(false)
}
func (b *BunAdapter) getDB() *bun.DB {
b.dbMu.RLock()
defer b.dbMu.RUnlock()
@@ -1249,30 +1239,28 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
}
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
}()
startedAt := time.Now()
if dest == nil {
return fmt.Errorf("destination cannot be nil")
err = fmt.Errorf("destination cannot be nil")
return err
}
err = b.query.Scan(ctx, dest)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
return err
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, nil)
return nil
return err
}
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
// Enhanced panic recovery with model information
@@ -1282,7 +1270,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
modelValue := model.Value()
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
// Try to get the model's underlying struct type
v := reflect.ValueOf(modelValue)
if v.Kind() == reflect.Ptr {
v = v.Elem()
@@ -1302,11 +1289,12 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
}()
if b.query.GetModel() == nil {
return fmt.Errorf("model is nil")
err = fmt.Errorf("model is nil")
return err
}
startedAt := time.Now()
// Optional: Enable detailed field-level debugging (set to true to debug)
const enableDetailedDebug = true
@@ -1321,45 +1309,40 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
err = b.query.Scan(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
return err
}
// After main query, load custom preloads using separate queries
if len(b.customPreloads) > 0 {
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
if err := b.loadCustomPreloads(ctx); err != nil {
if err = b.loadCustomPreloads(ctx); err != nil {
logger.Error("Failed to load custom preloads: %v", err)
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
return err
}
}
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, nil)
return nil
}
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0
}
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
}()
startedAt := time.Now()
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
count, err = b.query.Count(ctx) // assign to named returns, not shadow vars
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
return count, err
return
}
// Otherwise, wrap as subquery to avoid "Model(nil)" error
@@ -1369,30 +1352,27 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
ColumnExpr("COUNT(*)")
err = countQuery.Scan(ctx, &count)
if err != nil {
// Log SQL string for debugging
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
return count, err
return
}
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false
}
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
}()
startedAt := time.Now()
exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
return exists, err
return
}
// BunInsertQuery implements InsertQuery for Bun

View File

@@ -44,16 +44,6 @@ func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter {
return g
}
// EnableMetrics enables query metrics for this adapter.
func (g *GormAdapter) EnableMetrics() *GormAdapter {
return g.SetMetricsEnabled(true)
}
// DisableMetrics disables query metrics for this adapter.
func (g *GormAdapter) DisableMetrics() *GormAdapter {
return g.SetMetricsEnabled(false)
}
func (g *GormAdapter) getDB() *gorm.DB {
g.dbMu.RLock()
defer g.dbMu.RUnlock()
@@ -131,15 +121,15 @@ func (g *GormAdapter) NewSelect() common.SelectQuery {
}
func (g *GormAdapter) NewInsert() common.InsertQuery {
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -693,20 +683,21 @@ type GormInsertQuery struct {
schema string
tableName string
entity string
driverName string
metricsEnabled bool
}
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name())
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.db.Name())
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
@@ -767,20 +758,21 @@ type GormUpdateQuery struct {
schema string
tableName string
entity string
driverName string
metricsEnabled bool
}
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name())
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.db.Name())
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}
@@ -879,20 +871,21 @@ type GormDeleteQuery struct {
schema string
tableName string
entity string
driverName string
metricsEnabled bool
}
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
g.model = model
g.db = g.db.Model(model)
g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name())
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
g.entity = entityNameFromModel(model, g.tableName)
return g
}
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
g.db = g.db.Table(table)
g.schema, g.tableName = parseTableName(table, g.db.Name())
g.schema, g.tableName = parseTableName(table, g.driverName)
if g.entity == "" {
g.entity = cleanMetricIdentifier(g.tableName)
}

View File

@@ -49,16 +49,6 @@ func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter {
return p
}
// EnableMetrics enables query metrics for this adapter.
func (p *PgSQLAdapter) EnableMetrics() *PgSQLAdapter {
return p.SetMetricsEnabled(true)
}
// DisableMetrics disables query metrics for this adapter.
func (p *PgSQLAdapter) DisableMetrics() *PgSQLAdapter {
return p.SetMetricsEnabled(false)
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
@@ -544,16 +534,8 @@ func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
return p.Scan(ctx, p.model)
}
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
count = 0
}
}()
startedAt := time.Now()
// Build a COUNT query
// countInternal executes the COUNT query and returns the result without recording metrics.
func (p *PgSQLSelectQuery) countInternal(ctx context.Context) (int, error) {
var sb strings.Builder
sb.WriteString("SELECT COUNT(*) FROM ")
sb.WriteString(p.tableName)
@@ -587,7 +569,22 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
row = p.db.QueryRowContext(ctx, query, p.args...)
}
err = row.Scan(&count)
var count int
if err := row.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
count = 0
}
}()
startedAt := time.Now()
count, err = p.countInternal(ctx)
if err != nil {
logger.Error("PgSQL COUNT failed: %v", err)
}
@@ -602,8 +599,12 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
exists = false
}
}()
count, err := p.Count(ctx)
startedAt := time.Now()
count, err := p.countInternal(ctx)
if err != nil {
logger.Error("PgSQL EXISTS failed: %v", err)
}
recordQueryMetrics(p.metricsEnabled, "EXISTS", p.schema, p.entity, p.tableName, startedAt, err)
return count > 0, err
}
@@ -657,15 +658,17 @@ func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
}
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
}
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
}()
startedAt := time.Now()
if len(p.values) == 0 {
return nil, fmt.Errorf("no values to insert")
err = fmt.Errorf("no values to insert")
return nil, err
}
columns := make([]string, 0, len(p.values))
@@ -699,11 +702,9 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
if err != nil {
logger.Error("PgSQL INSERT failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
@@ -810,15 +811,17 @@ func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) strin
}
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
}
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
}()
startedAt := time.Now()
if len(p.sets) == 0 {
return nil, fmt.Errorf("no values to update")
err = fmt.Errorf("no values to update")
return nil, err
}
setClauses := make([]string, 0, len(p.sets))
@@ -881,11 +884,9 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
if err != nil {
logger.Error("PgSQL UPDATE failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, nil)
return &PgSQLResult{result: result}, nil
}
@@ -936,12 +937,13 @@ func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) strin
}
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
startedAt := time.Now()
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
}
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
}()
startedAt := time.Now()
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
@@ -960,11 +962,9 @@ func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err err
if err != nil {
logger.Error("PgSQL DELETE failed: %v", err)
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
return nil, err
}
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, nil)
return &PgSQLResult{result: result}, nil
}

View File

@@ -93,6 +93,9 @@ func schemaAndTableFromModel(model interface{}, driverName string) (schema, tabl
return parseTableName(provider.TableName(), driverName)
}
// tableNameProviderType is cached to avoid repeated reflection on every call.
var tableNameProviderType = reflect.TypeOf((*common.TableNameProvider)(nil)).Elem()
func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) {
if model == nil {
return nil, false
@@ -111,6 +114,12 @@ func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bo
return nil, false
}
// Check whether *T implements TableNameProvider before allocating.
ptrType := reflect.PointerTo(modelType)
if !ptrType.Implements(tableNameProviderType) && !modelType.Implements(tableNameProviderType) {
return nil, false
}
modelValue := reflect.New(modelType)
if provider, ok := modelValue.Interface().(common.TableNameProvider); ok {
return provider, true

View File

@@ -3,6 +3,7 @@ package database
import (
"context"
"database/sql"
"fmt"
"net/http"
"sync"
"testing"
@@ -86,8 +87,9 @@ func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) {
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
defer metrics.SetProvider(prev)
mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`).
WithArgs("Alice", 1).
@@ -117,14 +119,15 @@ func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) {
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
defer metrics.SetProvider(prev)
mock.ExpectExec(`DELETE FROM users WHERE id = \$1`).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db).DisableMetrics()
adapter := NewPgSQLAdapter(db).SetMetricsEnabled(false)
_, err = adapter.NewDelete().
Table("users").
Where("id = ?", 1).
@@ -143,8 +146,9 @@ func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error)
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
defer metrics.SetProvider(prev)
adapter := NewGormAdapter(db)
var users []queryMetricsGormUser
@@ -161,6 +165,109 @@ func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
assert.Equal(t, "metrics_gorm_users", calls[0].table)
}
func TestPgSQLAdapterRecordsErrorMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectExec(`INSERT INTO users`).
WillReturnError(fmt.Errorf("unique constraint violation"))
adapter := NewPgSQLAdapter(db)
_, err = adapter.NewInsert().
Table("users").
Value("name", "Alice").
Exec(context.Background())
require.Error(t, err)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "INSERT", calls[0].operation)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterRecordsExistsMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3))
adapter := NewPgSQLAdapter(db)
exists, err := adapter.NewSelect().Table("users").Exists(context.Background())
require.NoError(t, err)
assert.True(t, exists)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "EXISTS", calls[0].operation)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterRecordsCountMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
adapter := NewPgSQLAdapter(db)
count, err := adapter.NewSelect().Table("users").Count(context.Background())
require.NoError(t, err)
assert.Equal(t, 5, count)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "COUNT", calls[0].operation)
assert.Equal(t, "users", calls[0].table)
}
func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(prev)
mock.ExpectExec(`UPDATE public\.orders SET status = \$1`).
WithArgs("shipped").
WillReturnResult(sqlmock.NewResult(0, 2))
adapter := NewPgSQLAdapter(db)
_, err = adapter.Exec(context.Background(), `UPDATE public.orders SET status = $1`, "shipped")
require.NoError(t, err)
calls := provider.snapshot()
require.Len(t, calls, 1)
assert.Equal(t, "UPDATE", calls[0].operation)
assert.Equal(t, "public", calls[0].schema)
assert.Equal(t, "orders", calls[0].table)
}
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
require.NoError(t, err)
@@ -179,8 +286,9 @@ func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
require.NoError(t, err)
provider := &capturingMetricsProvider{}
prev := metrics.GetProvider()
metrics.SetProvider(provider)
defer metrics.SetProvider(nil)
defer metrics.SetProvider(prev)
adapter := NewBunAdapter(db)
var users []queryMetricsBunUser