diff --git a/pkg/common/adapters/database/query_metrics.go b/pkg/common/adapters/database/query_metrics.go index 2d48334..f457213 100644 --- a/pkg/common/adapters/database/query_metrics.go +++ b/pkg/common/adapters/database/query_metrics.go @@ -10,6 +10,8 @@ import ( "github.com/bitechdev/ResolveSpec/pkg/reflection" ) +const maxMetricFallbackEntityLength = 120 + func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) { if !enabled { return @@ -136,7 +138,7 @@ func metricTargetFromRawQuery(query, driverName string) (operation, schema, enti operation = normalizeMetricOperation(firstQueryKeyword(query)) tableRef := tableFromRawQuery(query, operation) if tableRef == "" { - return operation, "", "unknown", "unknown" + return operation, "", fallbackMetricEntityFromQuery(query), "unknown" } schema, table = parseTableName(tableRef, driverName) @@ -144,6 +146,133 @@ func metricTargetFromRawQuery(query, driverName string) (operation, schema, enti return operation, schema, entity, table } +func fallbackMetricEntityFromQuery(query string) string { + query = sanitizeMetricQueryShape(query) + if query == "" { + return "unknown" + } + + if len(query) > maxMetricFallbackEntityLength { + return query[:maxMetricFallbackEntityLength-3] + "..." + } + + return query +} + +func sanitizeMetricQueryShape(query string) string { + query = strings.TrimSpace(query) + if query == "" { + return "" + } + + var out strings.Builder + for i := 0; i < len(query); { + if query[i] == '\'' { + out.WriteByte('?') + i++ + for i < len(query) { + if query[i] == '\'' { + if i+1 < len(query) && query[i+1] == '\'' { + i += 2 + continue + } + i++ + break + } + i++ + } + continue + } + + if query[i] == '?' { + out.WriteByte('?') + i++ + continue + } + + if query[i] == '$' && i+1 < len(query) && isASCIIDigit(query[i+1]) { + out.WriteByte('?') + i++ + for i < len(query) && isASCIIDigit(query[i]) { + i++ + } + continue + } + + if query[i] == ':' && (i == 0 || query[i-1] != ':') && i+1 < len(query) && isIdentifierStart(query[i+1]) { + out.WriteByte('?') + i++ + for i < len(query) && isIdentifierPart(query[i]) { + i++ + } + continue + } + + if query[i] == '@' && (i == 0 || query[i-1] != '@') && i+1 < len(query) && isIdentifierStart(query[i+1]) { + out.WriteByte('?') + i++ + for i < len(query) && isIdentifierPart(query[i]) { + i++ + } + continue + } + + if startsNumericLiteral(query, i) { + out.WriteByte('?') + i++ + for i < len(query) && (isASCIIDigit(query[i]) || query[i] == '.') { + i++ + } + continue + } + + out.WriteByte(query[i]) + i++ + } + + return strings.Join(strings.Fields(out.String()), " ") +} + +func startsNumericLiteral(query string, idx int) bool { + if idx >= len(query) { + return false + } + + start := idx + if query[idx] == '-' { + if idx+1 >= len(query) || !isASCIIDigit(query[idx+1]) { + return false + } + start++ + } + + if !isASCIIDigit(query[start]) { + return false + } + + if idx > 0 && isIdentifierPart(query[idx-1]) { + return false + } + + if start+1 < len(query) && query[start] == '0' && (query[start+1] == 'x' || query[start+1] == 'X') { + return false + } + + return true +} + +func isASCIIDigit(ch byte) bool { + return ch >= '0' && ch <= '9' +} + +func isIdentifierStart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isIdentifierPart(ch byte) bool { + return isIdentifierStart(ch) || isASCIIDigit(ch) +} + func firstQueryKeyword(query string) string { query = strings.TrimSpace(query) if query == "" { diff --git a/pkg/common/adapters/database/query_metrics_test.go b/pkg/common/adapters/database/query_metrics_test.go index b0ab8a4..a91bf49 100644 --- a/pkg/common/adapters/database/query_metrics_test.go +++ b/pkg/common/adapters/database/query_metrics_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "net/http" + "strings" "sync" "testing" "time" @@ -268,6 +269,47 @@ func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) { assert.Equal(t, "orders", calls[0].table) } +func TestPgSQLAdapterRawExecUsesSQLAsEntityWhenTargetUnknown(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + provider := &capturingMetricsProvider{} + prev := metrics.GetProvider() + metrics.SetProvider(provider) + defer metrics.SetProvider(prev) + + query := `select core.c_setuserid($1)` + mock.ExpectExec(`select core\.c_setuserid\(\$1\)`). + WithArgs(42). + WillReturnResult(sqlmock.NewResult(0, 1)) + + adapter := NewPgSQLAdapter(db) + _, err = adapter.Exec(context.Background(), query, 42) + + require.NoError(t, err) + + calls := provider.snapshot() + require.Len(t, calls, 1) + assert.Equal(t, "SELECT", calls[0].operation) + assert.Equal(t, "default", calls[0].schema) + assert.Equal(t, "select core.c_setuserid(?)", calls[0].entity) + assert.Equal(t, "unknown", calls[0].table) +} + +func TestFallbackMetricEntityFromQuerySanitizesAndTruncates(t *testing.T) { + entity := fallbackMetricEntityFromQuery(" \n SELECT some_function(1, 'abc', $2, ?, :name, @p1, true, null) \t ") + assert.Equal(t, "SELECT some_function(?, ?, ?, ?, ?, ?, true, null)", entity) + + entity = fallbackMetricEntityFromQuery("SELECT price::numeric, id FROM logs WHERE code = -42") + assert.Equal(t, "SELECT price::numeric, id FROM logs WHERE code = ?", entity) + + longQuery := "SELECT " + strings.Repeat("x", maxMetricFallbackEntityLength) + entity = fallbackMetricEntityFromQuery(longQuery) + assert.Len(t, entity, maxMetricFallbackEntityLength) + assert.True(t, strings.HasSuffix(entity, "...")) +} + func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) { sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared") require.NoError(t, err)