diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index 598c7ef..b74011a 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -122,16 +122,6 @@ func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter { return b } -// EnableMetrics enables query metrics for this adapter. -func (b *BunAdapter) EnableMetrics() *BunAdapter { - return b.SetMetricsEnabled(true) -} - -// DisableMetrics disables query metrics for this adapter. -func (b *BunAdapter) DisableMetrics() *BunAdapter { - return b.SetMetricsEnabled(false) -} - func (b *BunAdapter) getDB() *bun.DB { b.dbMu.RLock() defer b.dbMu.RUnlock() @@ -1249,30 +1239,28 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec } func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { err = logger.HandlePanic("BunSelectQuery.Scan", r) } + recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err) }() - startedAt := time.Now() if dest == nil { - return fmt.Errorf("destination cannot be nil") + err = fmt.Errorf("destination cannot be nil") + return err } err = b.query.Scan(ctx, dest) if err != nil { - // Log SQL string for debugging sqlStr := b.query.String() logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err) - recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err) - return err } - - recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, nil) - return nil + return err } func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { // Enhanced panic recovery with model information @@ -1282,7 +1270,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { modelValue := model.Value() modelInfo = fmt.Sprintf("Model type: %T", modelValue) - // Try to get the model's underlying struct type v := reflect.ValueOf(modelValue) if v.Kind() == reflect.Ptr { v = v.Elem() @@ -1302,11 +1289,12 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr) err = logger.HandlePanic("BunSelectQuery.ScanModel", r) } + recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err) }() if b.query.GetModel() == nil { - return fmt.Errorf("model is nil") + err = fmt.Errorf("model is nil") + return err } - startedAt := time.Now() // Optional: Enable detailed field-level debugging (set to true to debug) const enableDetailedDebug = true @@ -1321,45 +1309,40 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { err = b.query.Scan(ctx) if err != nil { - // Log SQL string for debugging sqlStr := b.query.String() logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err) - recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err) return err } // After main query, load custom preloads using separate queries if len(b.customPreloads) > 0 { logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads)) - if err := b.loadCustomPreloads(ctx); err != nil { + if err = b.loadCustomPreloads(ctx); err != nil { logger.Error("Failed to load custom preloads: %v", err) - recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err) return err } } - recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, nil) return nil } func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { err = logger.HandlePanic("BunSelectQuery.Count", r) count = 0 } + recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err) }() - startedAt := time.Now() // If Model() was set, use bun's native Count() which works properly if b.hasModel { - count, err := b.query.Count(ctx) + count, err = b.query.Count(ctx) // assign to named returns, not shadow vars if err != nil { - // Log SQL string for debugging sqlStr := b.query.String() logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err) } - recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err) - return count, err + return } // Otherwise, wrap as subquery to avoid "Model(nil)" error @@ -1369,30 +1352,27 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { ColumnExpr("COUNT(*)") err = countQuery.Scan(ctx, &count) if err != nil { - // Log SQL string for debugging sqlStr := countQuery.String() logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err) } - recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err) - return count, err + return } func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { err = logger.HandlePanic("BunSelectQuery.Exists", r) exists = false } + recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err) }() - startedAt := time.Now() exists, err = b.query.Exists(ctx) if err != nil { - // Log SQL string for debugging sqlStr := b.query.String() logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err) } - recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err) - return exists, err + return } // BunInsertQuery implements InsertQuery for Bun diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 06ea4cd..d6b9935 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -44,16 +44,6 @@ func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter { return g } -// EnableMetrics enables query metrics for this adapter. -func (g *GormAdapter) EnableMetrics() *GormAdapter { - return g.SetMetricsEnabled(true) -} - -// DisableMetrics disables query metrics for this adapter. -func (g *GormAdapter) DisableMetrics() *GormAdapter { - return g.SetMetricsEnabled(false) -} - func (g *GormAdapter) getDB() *gorm.DB { g.dbMu.RLock() defer g.dbMu.RUnlock() @@ -131,15 +121,15 @@ func (g *GormAdapter) NewSelect() common.SelectQuery { } func (g *GormAdapter) NewInsert() common.InsertQuery { - return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled} + return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled} } func (g *GormAdapter) NewUpdate() common.UpdateQuery { - return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled} + return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled} } func (g *GormAdapter) NewDelete() common.DeleteQuery { - return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled} + return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled} } func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { @@ -693,20 +683,21 @@ type GormInsertQuery struct { schema string tableName string entity string + driverName string metricsEnabled bool } func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery { g.model = model g.db = g.db.Model(model) - g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name()) + g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName) g.entity = entityNameFromModel(model, g.tableName) return g } func (g *GormInsertQuery) Table(table string) common.InsertQuery { g.db = g.db.Table(table) - g.schema, g.tableName = parseTableName(table, g.db.Name()) + g.schema, g.tableName = parseTableName(table, g.driverName) if g.entity == "" { g.entity = cleanMetricIdentifier(g.tableName) } @@ -767,20 +758,21 @@ type GormUpdateQuery struct { schema string tableName string entity string + driverName string metricsEnabled bool } func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery { g.model = model g.db = g.db.Model(model) - g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name()) + g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName) g.entity = entityNameFromModel(model, g.tableName) return g } func (g *GormUpdateQuery) Table(table string) common.UpdateQuery { g.db = g.db.Table(table) - g.schema, g.tableName = parseTableName(table, g.db.Name()) + g.schema, g.tableName = parseTableName(table, g.driverName) if g.entity == "" { g.entity = cleanMetricIdentifier(g.tableName) } @@ -879,20 +871,21 @@ type GormDeleteQuery struct { schema string tableName string entity string + driverName string metricsEnabled bool } func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery { g.model = model g.db = g.db.Model(model) - g.schema, g.tableName = schemaAndTableFromModel(model, g.db.Name()) + g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName) g.entity = entityNameFromModel(model, g.tableName) return g } func (g *GormDeleteQuery) Table(table string) common.DeleteQuery { g.db = g.db.Table(table) - g.schema, g.tableName = parseTableName(table, g.db.Name()) + g.schema, g.tableName = parseTableName(table, g.driverName) if g.entity == "" { g.entity = cleanMetricIdentifier(g.tableName) } diff --git a/pkg/common/adapters/database/pgsql.go b/pkg/common/adapters/database/pgsql.go index 8f4bf6c..64cd3a3 100644 --- a/pkg/common/adapters/database/pgsql.go +++ b/pkg/common/adapters/database/pgsql.go @@ -49,16 +49,6 @@ func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter { return p } -// EnableMetrics enables query metrics for this adapter. -func (p *PgSQLAdapter) EnableMetrics() *PgSQLAdapter { - return p.SetMetricsEnabled(true) -} - -// DisableMetrics disables query metrics for this adapter. -func (p *PgSQLAdapter) DisableMetrics() *PgSQLAdapter { - return p.SetMetricsEnabled(false) -} - func (p *PgSQLAdapter) getDB() *sql.DB { p.dbMu.RLock() defer p.dbMu.RUnlock() @@ -544,16 +534,8 @@ func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error { return p.Scan(ctx, p.model) } -func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) { - defer func() { - if r := recover(); r != nil { - err = logger.HandlePanic("PgSQLSelectQuery.Count", r) - count = 0 - } - }() - startedAt := time.Now() - - // Build a COUNT query +// countInternal executes the COUNT query and returns the result without recording metrics. +func (p *PgSQLSelectQuery) countInternal(ctx context.Context) (int, error) { var sb strings.Builder sb.WriteString("SELECT COUNT(*) FROM ") sb.WriteString(p.tableName) @@ -587,7 +569,22 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) { row = p.db.QueryRowContext(ctx, query, p.args...) } - err = row.Scan(&count) + var count int + if err := row.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLSelectQuery.Count", r) + count = 0 + } + }() + startedAt := time.Now() + count, err = p.countInternal(ctx) if err != nil { logger.Error("PgSQL COUNT failed: %v", err) } @@ -602,8 +599,12 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error) exists = false } }() - - count, err := p.Count(ctx) + startedAt := time.Now() + count, err := p.countInternal(ctx) + if err != nil { + logger.Error("PgSQL EXISTS failed: %v", err) + } + recordQueryMetrics(p.metricsEnabled, "EXISTS", p.schema, p.entity, p.tableName, startedAt, err) return count > 0, err } @@ -657,15 +658,17 @@ func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery { } func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { err = logger.HandlePanic("PgSQLInsertQuery.Exec", r) } + recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err) }() - startedAt := time.Now() if len(p.values) == 0 { - return nil, fmt.Errorf("no values to insert") + err = fmt.Errorf("no values to insert") + return nil, err } columns := make([]string, 0, len(p.values)) @@ -699,11 +702,9 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err if err != nil { logger.Error("PgSQL INSERT failed: %v", err) - recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err) return nil, err } - recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, nil) return &PgSQLResult{result: result}, nil } @@ -810,15 +811,17 @@ func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) strin } func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r) } + recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err) }() - startedAt := time.Now() if len(p.sets) == 0 { - return nil, fmt.Errorf("no values to update") + err = fmt.Errorf("no values to update") + return nil, err } setClauses := make([]string, 0, len(p.sets)) @@ -881,11 +884,9 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err if err != nil { logger.Error("PgSQL UPDATE failed: %v", err) - recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err) return nil, err } - recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, nil) return &PgSQLResult{result: result}, nil } @@ -936,12 +937,13 @@ func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) strin } func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { + startedAt := time.Now() defer func() { if r := recover(); r != nil { err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r) } + recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err) }() - startedAt := time.Now() query := fmt.Sprintf("DELETE FROM %s", p.tableName) @@ -960,11 +962,9 @@ func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err err if err != nil { logger.Error("PgSQL DELETE failed: %v", err) - recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err) return nil, err } - recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, nil) return &PgSQLResult{result: result}, nil } diff --git a/pkg/common/adapters/database/query_metrics.go b/pkg/common/adapters/database/query_metrics.go index 1a00ef4..2d48334 100644 --- a/pkg/common/adapters/database/query_metrics.go +++ b/pkg/common/adapters/database/query_metrics.go @@ -93,6 +93,9 @@ func schemaAndTableFromModel(model interface{}, driverName string) (schema, tabl return parseTableName(provider.TableName(), driverName) } +// tableNameProviderType is cached to avoid repeated reflection on every call. +var tableNameProviderType = reflect.TypeOf((*common.TableNameProvider)(nil)).Elem() + func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) { if model == nil { return nil, false @@ -111,6 +114,12 @@ func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bo return nil, false } + // Check whether *T implements TableNameProvider before allocating. + ptrType := reflect.PointerTo(modelType) + if !ptrType.Implements(tableNameProviderType) && !modelType.Implements(tableNameProviderType) { + return nil, false + } + modelValue := reflect.New(modelType) if provider, ok := modelValue.Interface().(common.TableNameProvider); ok { return provider, true diff --git a/pkg/common/adapters/database/query_metrics_test.go b/pkg/common/adapters/database/query_metrics_test.go index f281334..b0ab8a4 100644 --- a/pkg/common/adapters/database/query_metrics_test.go +++ b/pkg/common/adapters/database/query_metrics_test.go @@ -3,6 +3,7 @@ package database import ( "context" "database/sql" + "fmt" "net/http" "sync" "testing" @@ -86,8 +87,9 @@ func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) { defer db.Close() provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() metrics.SetProvider(provider) - defer metrics.SetProvider(nil) + defer metrics.SetProvider(prev) mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`). WithArgs("Alice", 1). @@ -117,14 +119,15 @@ func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) { defer db.Close() provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() metrics.SetProvider(provider) - defer metrics.SetProvider(nil) + defer metrics.SetProvider(prev) mock.ExpectExec(`DELETE FROM users WHERE id = \$1`). WithArgs(1). WillReturnResult(sqlmock.NewResult(0, 1)) - adapter := NewPgSQLAdapter(db).DisableMetrics() + adapter := NewPgSQLAdapter(db).SetMetricsEnabled(false) _, err = adapter.NewDelete(). Table("users"). Where("id = ?", 1). @@ -143,8 +146,9 @@ func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) { require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error) provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() metrics.SetProvider(provider) - defer metrics.SetProvider(nil) + defer metrics.SetProvider(prev) adapter := NewGormAdapter(db) var users []queryMetricsGormUser @@ -161,6 +165,109 @@ func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) { assert.Equal(t, "metrics_gorm_users", calls[0].table) } +func TestPgSQLAdapterRecordsErrorMetric(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() + metrics.SetProvider(provider) + defer metrics.SetProvider(prev) + + mock.ExpectExec(`INSERT INTO users`). + WillReturnError(fmt.Errorf("unique constraint violation")) + + adapter := NewPgSQLAdapter(db) + _, err = adapter.NewInsert(). + Table("users"). + Value("name", "Alice"). + Exec(context.Background()) + + require.Error(t, err) + + calls := provider.snapshot() + require.Len(t, calls, 1) + assert.Equal(t, "INSERT", calls[0].operation) + assert.Equal(t, "users", calls[0].table) +} + +func TestPgSQLAdapterRecordsExistsMetric(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() + metrics.SetProvider(provider) + defer metrics.SetProvider(prev) + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3)) + + adapter := NewPgSQLAdapter(db) + exists, err := adapter.NewSelect().Table("users").Exists(context.Background()) + + require.NoError(t, err) + assert.True(t, exists) + + calls := provider.snapshot() + require.Len(t, calls, 1) + assert.Equal(t, "EXISTS", calls[0].operation) + assert.Equal(t, "users", calls[0].table) +} + +func TestPgSQLAdapterRecordsCountMetric(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() + metrics.SetProvider(provider) + defer metrics.SetProvider(prev) + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5)) + + adapter := NewPgSQLAdapter(db) + count, err := adapter.NewSelect().Table("users").Count(context.Background()) + + require.NoError(t, err) + assert.Equal(t, 5, count) + + calls := provider.snapshot() + require.Len(t, calls, 1) + assert.Equal(t, "COUNT", calls[0].operation) + assert.Equal(t, "users", calls[0].table) +} + +func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() + metrics.SetProvider(provider) + defer metrics.SetProvider(prev) + + mock.ExpectExec(`UPDATE public\.orders SET status = \$1`). + WithArgs("shipped"). + WillReturnResult(sqlmock.NewResult(0, 2)) + + adapter := NewPgSQLAdapter(db) + _, err = adapter.Exec(context.Background(), `UPDATE public.orders SET status = $1`, "shipped") + + require.NoError(t, err) + + calls := provider.snapshot() + require.Len(t, calls, 1) + assert.Equal(t, "UPDATE", calls[0].operation) + assert.Equal(t, "public", calls[0].schema) + assert.Equal(t, "orders", calls[0].table) +} + func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) { sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared") require.NoError(t, err) @@ -179,8 +286,9 @@ func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) { require.NoError(t, err) provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() metrics.SetProvider(provider) - defer metrics.SetProvider(nil) + defer metrics.SetProvider(prev) adapter := NewBunAdapter(db) var users []queryMetricsBunUser diff --git a/pkg/metrics/interfaces.go b/pkg/metrics/interfaces.go index 07cdd9f..c040297 100644 --- a/pkg/metrics/interfaces.go +++ b/pkg/metrics/interfaces.go @@ -2,6 +2,7 @@ package metrics import ( "net/http" + "sync" "time" "github.com/bitechdev/ResolveSpec/pkg/logger" @@ -46,21 +47,28 @@ type Provider interface { Handler() http.Handler } -// globalProvider is the global metrics provider -var globalProvider Provider +// globalProvider is the global metrics provider, protected by globalProviderMu. +var ( + globalProviderMu sync.RWMutex + globalProvider Provider +) -// SetProvider sets the global metrics provider +// SetProvider sets the global metrics provider. func SetProvider(p Provider) { + globalProviderMu.Lock() globalProvider = p + globalProviderMu.Unlock() } -// GetProvider returns the current metrics provider +// GetProvider returns the current metrics provider. func GetProvider() Provider { - if globalProvider == nil { - // Return no-op provider if none is set + globalProviderMu.RLock() + p := globalProvider + globalProviderMu.RUnlock() + if p == nil { return &NoOpProvider{} } - return globalProvider + return p } // NoOpProvider is a no-op implementation of Provider