Compare commits

...

15 Commits

Author SHA1 Message Date
Hein
c696d502c5 extractTableAndColumn 2025-12-10 10:10:55 +02:00
Hein
4ed1fba6ad Fixed extractTableAndColumn 2025-12-10 10:10:43 +02:00
Hein
1d0407a16d Fixed linting
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-10 10:00:01 +02:00
Hein
99001c749d Better sql where validation 2025-12-10 09:52:13 +02:00
Hein
1f7a57f8e3 Tracking provider 2025-12-10 09:31:55 +02:00
Hein
a95c28a0bf Multi Token warning and handling 2025-12-10 08:44:37 +02:00
Hein
e1abd5ebc1 Enhanced the SanitizeWhereClause function 2025-12-10 08:36:24 +02:00
Hein
ca4e53969b Better tests
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 15:32:16 +02:00
Hein
db2b7e878e Better handling of preloads 2025-12-09 15:12:17 +02:00
Hein
9572bfc7b8 Fix qualified column reference (like APIL.rid_hub) in a preload: 2025-12-09 14:46:33 +02:00
Hein
f0962ea1ec Added EnableQueryDebug log 2025-12-09 14:37:09 +02:00
Hein
8fcb065b42 Better Query Debugging 2025-12-09 14:31:26 +02:00
Hein
dc3b621380 Fixed test for session id changes 2025-12-09 14:07:00 +02:00
Hein
a4dd2a7086 exposed types FromString 2025-12-09 14:03:55 +02:00
Hein
3ec2e5f15a Proper handling of fromString in the types 2025-12-09 13:55:51 +02:00
28 changed files with 3729 additions and 255 deletions

View File

@@ -71,35 +71,18 @@
}, },
"gocritic": { "gocritic": {
"enabled-checks": [ "enabled-checks": [
"appendAssign",
"assignOp",
"boolExprSimplify", "boolExprSimplify",
"builtinShadow", "builtinShadow",
"captLocal",
"caseOrder",
"defaultCaseOrder",
"dupArg",
"dupBranchBody",
"dupCase",
"dupSubExpr",
"elseif",
"emptyFallthrough", "emptyFallthrough",
"equalFold", "equalFold",
"flagName",
"indexAlloc", "indexAlloc",
"initClause", "initClause",
"methodExprCall", "methodExprCall",
"nilValReturn", "nilValReturn",
"rangeExprCopy", "rangeExprCopy",
"rangeValCopy", "rangeValCopy",
"regexpMust",
"singleCaseSwitch",
"sloppyLen",
"stringXbytes", "stringXbytes",
"switchTrue",
"typeAssertChain", "typeAssertChain",
"typeSwitchVar",
"underef",
"unlabelStmt", "unlabelStmt",
"unnamedResult", "unnamedResult",
"unnecessaryBlock", "unnecessaryBlock",

5
go.mod
View File

@@ -5,12 +5,15 @@ go 1.24.0
toolchain go1.24.6 toolchain go1.24.6
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
github.com/getsentry/sentry-go v0.40.0
github.com/glebarez/sqlite v1.11.0 github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1 github.com/gorilla/mux v1.8.1
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1 github.com/redis/go-redis/v9 v9.17.1
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
@@ -30,7 +33,6 @@ require (
) )
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -65,7 +67,6 @@ require (
github.com/spf13/afero v1.15.0 // indirect github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/pflag v1.0.10 // indirect
github.com/spf13/viper v1.21.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect

10
go.sum
View File

@@ -19,12 +19,18 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -75,6 +81,10 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=

View File

@@ -0,0 +1,218 @@
# Automatic Relation Loading Strategies
## Overview
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
Simply use `PreloadRelation()` and the system automatically:
- Detects relationship type from Bun/GORM tags
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
## How It Works
```go
// Just write this - the system handles the rest!
db.NewSelect().
Model(&links).
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
Scan(ctx, &links)
```
### Detection Logic
The system inspects your model's struct tags:
**Bun models:**
```go
type Link struct {
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
}
```
**GORM models:**
```go
type Link struct {
ProviderID int
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
}
```
**Type inference (fallback):**
- `[]Type` (slice) → has-many → Separate query
- `*Type` (pointer) → belongs-to → JOIN
- `Type` (struct) → belongs-to → JOIN
### What Gets Logged
Enable debug logging to see strategy selection:
```go
bunAdapter.EnableQueryDebug()
```
**Output:**
```
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
INFO: Using JOIN strategy for belongs-to relation 'Provider'
DEBUG: PreloadRelation 'Links' detected as: has-many
DEBUG: Using separate query for has-many relation 'Links'
```
## Relationship Types
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|---------|--------------|------------|----------|-----|
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
## Manual Override
If you need to force a specific strategy, use `JoinRelation()`:
```go
// Force JOIN even for has-many (not recommended)
db.NewSelect().
Model(&providers).
JoinRelation("Links"). // Explicitly use JOIN
Scan(ctx, &providers)
```
## Examples
### Automatic Strategy Selection (Recommended)
```go
// Example 1: Loading parent provider for each link
// System detects belongs-to → uses JOIN automatically
db.NewSelect().
Model(&links).
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &links)
// Generated SQL: Single query with JOIN
// SELECT links.*, providers.*
// FROM links
// LEFT JOIN providers ON links.provider_id = providers.id
// WHERE providers.active = true
// Example 2: Loading child links for each provider
// System detects has-many → uses separate query automatically
db.NewSelect().
Model(&providers).
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &providers)
// Generated SQL: Two queries
// Query 1: SELECT * FROM providers
// Query 2: SELECT * FROM links
// WHERE provider_id IN (1, 2, 3, ...)
// AND active = true
```
### Mixed Relationships
```go
type Order struct {
ID int
CustomerID int
Customer *Customer `bun:"rel:belongs-to"` // JOIN
Items []Item `bun:"rel:has-many"` // Separate
Invoice *Invoice `bun:"rel:has-one"` // JOIN
}
// All three handled optimally!
db.NewSelect().
Model(&orders).
PreloadRelation("Customer"). // → JOIN (many-to-one)
PreloadRelation("Items"). // → Separate (one-to-many)
PreloadRelation("Invoice"). // → JOIN (one-to-one)
Scan(ctx, &orders)
```
## Performance Benefits
### Before (Manual Strategy Selection)
```go
// You had to remember which to use:
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
.PreloadRelation("Links") // Which is more efficient here?
```
### After (Automatic Selection)
```go
// Just use PreloadRelation everywhere:
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
.PreloadRelation("Links") // ✓ System uses separate query automatically
```
## Migration Guide
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
```go
// Before: Always used separate query
.PreloadRelation("Provider") // Inefficient: extra round trip
// After: Automatic optimization
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
```
## Implementation Details
### Supported Bun Tags
- `rel:has-many` → Separate query
- `rel:belongs-to` → JOIN
- `rel:has-one` → JOIN
- `rel:many-to-many` or `rel:m2m` → Separate query
### Supported GORM Patterns
- `many2many:` tag → Separate query
- `foreignKey:` tag → JOIN (belongs-to)
- `[]Type` slice without many2many → Separate query (has-many)
- `*Type` pointer with foreignKey → JOIN (belongs-to)
- `*Type` pointer without foreignKey → JOIN (has-one)
### Fallback Behavior
- `[]Type` (slice) → Separate query (safe default for collections)
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
- Unknown → Separate query (safest default)
## Debugging
To see strategy selection in action:
```go
// Enable debug logging
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
// Run your query
db.NewSelect().
Model(&records).
PreloadRelation("RelationName").
Scan(ctx, &records)
// Check logs for:
// - "PreloadRelation 'X' detected as: belongs-to"
// - "Using JOIN strategy for belongs-to relation 'X'"
// - Actual SQL queries executed
```
## Best Practices
1. **Use PreloadRelation() for everything** - Let the system optimize
2. **Define proper relationship tags** - Ensures correct detection
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
4. **Enable debug logging during development** - Verify optimal strategies are chosen
5. **Trust the system** - It's designed to choose correctly based on relationship type

View File

@@ -0,0 +1,81 @@
package database
import (
"testing"
)
func TestNormalizeTableAlias(t *testing.T) {
tests := []struct {
name string
query string
expectedAlias string
tableName string
want string
}{
{
name: "strips plausible alias from simple condition",
query: "APIL.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = 2576",
},
{
name: "keeps correct alias",
query: "apiproviderlink.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "apiproviderlink.rid_hub = 2576",
},
{
name: "strips plausible alias with multiple conditions",
query: "APIL.rid_hub = ? AND APIL.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND active = ?",
},
{
name: "handles mixed correct and plausible aliases",
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND apiproviderlink.active = ?",
},
{
name: "handles parentheses",
query: "(APIL.rid_hub = ?)",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "(rid_hub = ?)",
},
{
name: "no alias in query",
query: "rid_hub = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ?",
},
{
name: "keeps reference to different table (not in current table name)",
query: "APIL.rid_hub = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "APIL.rid_hub = ?",
},
{
name: "keeps reference with short prefix that might be ambiguous",
query: "AP.rid = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "AP.rid = ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
if got != tt.want {
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -15,6 +16,24 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// QueryDebugHook is a Bun query hook that logs all SQL queries including preloads
type QueryDebugHook struct{}
func (h *QueryDebugHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context {
return ctx
}
func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {
query := event.Query
duration := time.Since(event.StartTime)
if event.Err != nil {
logger.Error("SQL Query Failed [%s]: %s. Error: %v", duration, query, event.Err)
} else {
logger.Debug("SQL Query Success [%s]: %s", duration, query)
}
}
// BunAdapter adapts Bun to work with our Database interface // BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs // This demonstrates how the abstraction works with different ORMs
type BunAdapter struct { type BunAdapter struct {
@@ -26,6 +45,20 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
return &BunAdapter{db: db} return &BunAdapter{db: db}
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
}
// DisableQueryDebug removes all query hooks
func (b *BunAdapter) DisableQueryDebug() {
// Create a new DB without hooks
// Note: Bun doesn't have a RemoveQueryHook, so we'd need to track hooks manually
logger.Info("To disable query debug, recreate the BunAdapter without adding the hook")
}
func (b *BunAdapter) NewSelect() common.SelectQuery { func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{ return &BunSelectQuery{
query: b.db.NewSelect(), query: b.db.NewSelect(),
@@ -107,6 +140,8 @@ type BunSelectQuery struct {
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries deferredPreloads []deferredPreload // Preloads to execute as separate queries
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
} }
// deferredPreload represents a preload that will be executed as a separate query // deferredPreload represents a preload that will be executed as a separate query
@@ -156,10 +191,147 @@ func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.Se
} }
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if b.inJoinContext && b.joinTableAlias != "" {
query = addTablePrefix(query, b.joinTableAlias)
} else if b.tableAlias != "" && b.tableName != "" {
// If we have a table alias defined, check if the query references a different alias
// This can happen in preloads where the user expects a certain alias but Bun generates another
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
}
b.query = b.query.Where(query, args...) b.query = b.query.Where(query, args...)
return b return b
} }
// addTablePrefix adds a table prefix to unqualified column references
// This is used in JOIN contexts where conditions must reference the joined table
func addTablePrefix(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
// (no dot, and likely a column name before an operator)
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeyword(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
func isOperatorOrKeyword(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
// isAcronymMatch checks if prefix is an acronym of tableName
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
func isAcronymMatch(prefix, tableName string) bool {
if len(prefix) == 0 || len(tableName) == 0 {
return false
}
prefixIdx := 0
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
if tableName[i] == prefix[prefixIdx] {
prefixIdx++
}
}
// All characters of prefix were found in sequence in tableName
return prefixIdx == len(prefix)
}
// normalizeTableAlias replaces table alias prefixes in SQL conditions
// This handles cases where a user references a table alias that doesn't match
// what Bun generates (common in preload contexts)
func normalizeTableAlias(query, expectedAlias, tableName string) string {
// Pattern: <word>.<column> where <word> might be an incorrect alias
// We'll look for patterns like "APIL.column" and either:
// 1. Remove the alias prefix if it's clearly meant for this table
// 2. Leave it alone if it might be referring to another table (JOIN/preload)
// Split on spaces and parentheses to find qualified references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like a qualified column reference
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
prefix := part[:dotIndex]
column := part[dotIndex+1:]
// Check if the prefix matches our expected alias or table name (case-insensitive)
if strings.EqualFold(prefix, expectedAlias) ||
strings.EqualFold(prefix, tableName) ||
strings.EqualFold(prefix, strings.ToLower(tableName)) {
// Prefix matches current table, it's safe but redundant - leave it
continue
}
// Check if the prefix could plausibly be an alias/acronym for this table
// Only strip if we're confident it's meant for this table
// For example: "APIL" could be an acronym for "apiproviderlink"
prefixLower := strings.ToLower(prefix)
tableNameLower := strings.ToLower(tableName)
// Check if prefix is a substring of table name
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
// Check if prefix is an acronym of table name
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
isAcronym := false
if !isSubstring && len(prefixLower) > 2 {
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
}
if isSubstring || isAcronym {
// This looks like it could be an alias for this table - strip it
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
// Replace the qualified reference with just the column name
modified = strings.ReplaceAll(modified, part, column)
} else {
// Prefix doesn't match the current table at all
// It's likely referring to a different table (JOIN/preload)
// DON'T strip it - leave the qualified reference as-is
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
}
}
}
return modified
}
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.WhereOr(query, args...) b.query = b.query.WhereOr(query, args...)
return b return b
@@ -288,6 +460,27 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
// } // }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from the query if available
model := b.query.GetModel()
if model != nil && model.Value() != nil {
relType := reflection.GetRelationType(model.Value(), relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return b.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Check if this relation chain would create problematic long aliases // Check if this relation chain would create problematic long aliases
relationParts := strings.Split(relation, ".") relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__")) aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
@@ -350,6 +543,28 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
db: b.db, db: b.db,
} }
// Try to extract table name and alias from the preload model
if model := sq.GetModel(); model != nil && model.Value() != nil {
modelValue := model.Value()
// Extract table name if model implements TableNameProvider
if provider, ok := modelValue.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
}
// Extract table alias if model implements TableAliasProvider
if provider, ok := modelValue.(common.TableAliasProvider); ok {
wrapper.tableAlias = provider.TableAlias()
// Apply the alias to the Bun query so conditions can reference it
if wrapper.tableAlias != "" {
// Note: Bun's Relation() already sets up the table, but we can add
// the alias explicitly if needed
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
}
}
}
// Start with the interface value (not pointer) // Start with the interface value (not pointer)
current := common.SelectQuery(wrapper) current := common.SelectQuery(wrapper)
@@ -372,6 +587,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
return b return b
} }
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a LEFT JOIN instead of a separate query
// This is more efficient for many-to-one or one-to-one relationships
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
// Wrap the apply functions to automatically add table prefix to WHERE conditions
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
for _, fn := range apply {
if fn != nil {
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
return func(q common.SelectQuery) common.SelectQuery {
// Create a special wrapper that adds prefixes to WHERE conditions
if bunQuery, ok := q.(*BunSelectQuery); ok {
// Mark this query as being in JOIN context
bunQuery.inJoinContext = true
bunQuery.joinTableAlias = strings.ToLower(relation)
}
return originalFn(q)
}
}(fn)
wrappedApply = append(wrappedApply, wrappedFn)
}
}
// Use PreloadRelation with the wrapped functions
// Bun's Relation() will use JOIN for belongs-to and has-one relations
return b.PreloadRelation(relation, wrappedApply...)
}
func (b *BunSelectQuery) Order(order string) common.SelectQuery { func (b *BunSelectQuery) Order(order string) common.SelectQuery {
b.query = b.query.Order(order) b.query = b.query.Order(order)
return b return b
@@ -410,6 +655,9 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
// Execute the main query first // Execute the main query first
err = b.query.Scan(ctx, dest) err = b.query.Scan(ctx, dest)
if err != nil { if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
return err return err
} }
@@ -438,6 +686,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
// Execute the main query first // Execute the main query first
err = b.query.Scan(ctx) err = b.query.Scan(ctx)
if err != nil { if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
return err return err
} }
@@ -573,15 +824,25 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
// If Model() was set, use bun's native Count() which works properly // If Model() was set, use bun's native Count() which works properly
if b.hasModel { if b.hasModel {
count, err := b.query.Count(ctx) count, err := b.query.Count(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err return count, err
} }
// Otherwise, wrap as subquery to avoid "Model(nil)" error // Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model // This is needed when only Table() is set without a model
err = b.db.NewSelect(). countQuery := b.db.NewSelect().
TableExpr("(?) AS subquery", b.query). TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)"). ColumnExpr("COUNT(*)")
Scan(ctx, &count) err = countQuery.Scan(ctx, &count)
if err != nil {
// Log SQL string for debugging
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err return count, err
} }
@@ -592,7 +853,13 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
exists = false exists = false
} }
}() }()
return b.query.Exists(ctx) exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return exists, err
} }
// BunInsertQuery implements InsertQuery for Bun // BunInsertQuery implements InsertQuery for Bun
@@ -729,6 +996,11 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
} }
}() }()
result, err := b.query.Exec(ctx) result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err return &BunResult{result: result}, err
} }
@@ -759,6 +1031,11 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
} }
}() }()
result, err := b.query.Exec(ctx) result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err return &BunResult{result: result}, err
} }

View File

@@ -23,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
return &GormAdapter{db: db} return &GormAdapter{db: db}
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.db = g.db.Debug()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
// DisableQueryDebug disables query debugging
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
// GORM's Debug() creates a new session, so we need to get the base DB
// This is a simplified implementation
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
return g
}
func (g *GormAdapter) NewSelect() common.SelectQuery { func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db} return &GormSelectQuery{db: g.db}
} }
@@ -88,10 +104,12 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
// GormSelectQuery implements SelectQuery for GORM // GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct { type GormSelectQuery struct {
db *gorm.DB db *gorm.DB
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
} }
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -135,10 +153,61 @@ func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.S
} }
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if g.inJoinContext && g.joinTableAlias != "" {
query = addTablePrefixGorm(query, g.joinTableAlias)
}
g.db = g.db.Where(query, args...) g.db = g.db.Where(query, args...)
return g return g
} }
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
func addTablePrefixGorm(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeywordGorm(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
func isOperatorOrKeywordGorm(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Or(query, args...) g.db = g.db.Or(query, args...)
return g return g
@@ -222,6 +291,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
} }
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from GORM's statement if available
if g.db.Statement != nil && g.db.Statement.Model != nil {
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return g.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Use GORM's Preload (separate query strategy)
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB { g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 { if len(apply) == 0 {
return db return db
@@ -251,6 +341,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
return g return g
} }
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a JOIN instead of a separate preload query
// This is more efficient for many-to-one or one-to-one relationships
// as it avoids additional round trips to the database
// GORM's Joins() method forces a JOIN for the preload
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
}
wrapper := &GormSelectQuery{
db: db,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalGorm, ok := current.(*GormSelectQuery); ok {
return finalGorm.db
}
return db
})
return g
}
func (g *GormSelectQuery) Order(order string) common.SelectQuery { func (g *GormSelectQuery) Order(order string) common.SelectQuery {
g.db = g.db.Order(order) g.db = g.db.Order(order)
return g return g
@@ -282,7 +408,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
err = logger.HandlePanic("GormSelectQuery.Scan", r) err = logger.HandlePanic("GormSelectQuery.Scan", r)
} }
}() }()
return g.db.WithContext(ctx).Find(dest).Error err = g.db.WithContext(ctx).Find(dest).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(dest)
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
} }
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
@@ -294,7 +428,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
if g.db.Statement.Model == nil { if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning") return fmt.Errorf("ScanModel requires Model() to be set before scanning")
} }
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(g.db.Statement.Model)
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
} }
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
@@ -306,6 +448,13 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
}() }()
var count64 int64 var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error err = g.db.WithContext(ctx).Count(&count64).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Count(&count64)
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return int(count64), err return int(count64), err
} }
@@ -318,6 +467,13 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
}() }()
var count int64 var count int64
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Limit(1).Count(&count)
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return count > 0, err return count > 0, err
} }
@@ -456,6 +612,13 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
} }
}() }()
result := g.db.WithContext(ctx).Updates(g.updates) result := g.db.WithContext(ctx).Updates(g.updates)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Updates(g.updates)
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }
@@ -488,6 +651,13 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
} }
}() }()
result := g.db.WithContext(ctx).Delete(g.model) result := g.db.WithContext(ctx).Delete(g.model)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Delete(g.model)
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }

View File

@@ -38,6 +38,7 @@ type SelectQuery interface {
LeftJoin(query string, args ...interface{}) SelectQuery LeftJoin(query string, args ...interface{}) SelectQuery
Preload(relation string, conditions ...interface{}) SelectQuery Preload(relation string, conditions ...interface{}) SelectQuery
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
Order(order string) SelectQuery Order(order string) SelectQuery
Limit(n int) SelectQuery Limit(n int) SelectQuery
Offset(n int) SelectQuery Offset(n int) SelectQuery

View File

@@ -9,81 +9,40 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains // ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
// the relation prefix (alias). If not present, it attempts to add it to column references. //
// Returns the fixed WHERE clause and an error if it cannot be safely fixed. // NOTE: For preload queries, table aliases from the parent query are not valid since
// the preload executes as a separate query with its own table alias. This function
// now simply validates basic syntax without requiring or adding prefixes.
// The actual alias normalization happens in the database adapter layer.
//
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) { func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" { if where == "" {
return where, nil return where, nil
} }
// Check if the relation name is already present in the WHERE clause where = strings.TrimSpace(where)
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot // Just do basic validation - don't require or add prefixes
if strings.Contains(lowerWhere, lowerRelation+".") || // The database adapter will handle alias normalization
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") { // Check if the WHERE clause contains any qualified column references
// Relation prefix is already present // If it does, log a debug message but don't fail - let the adapter handle it
if strings.Contains(where, ".") {
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
"Note: In preload context, table aliases from parent query are not available. "+
"The database adapter will normalize aliases automatically.", relationName, where)
}
// Validate that it's not empty or just whitespace
if where == "" {
return where, nil return where, nil
} }
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.), // Return the WHERE clause as-is
// we can't safely auto-fix it - require explicit prefix // The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
if strings.Contains(lowerWhere, " or ") || return where, nil
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Check if this is a SQL expression/literal that shouldn't be prefixed
lowerCond := strings.ToLower(strings.TrimSpace(cond))
if IsSQLExpression(lowerCond) {
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
fixedConditions = append(fixedConditions, cond)
continue
}
// Extract the column name (first identifier before operator)
columnName := ExtractColumnName(cond)
if columnName == "" {
// Can't identify column name, require explicit prefix
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
}
// Add relation prefix to the column name only
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
} }
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed // IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
@@ -120,23 +79,69 @@ func IsTrivialCondition(cond string) bool {
return false return false
} }
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns // validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
// Returns an error if any dangerous keywords are found
func validateWhereClauseSecurity(where string) error {
if where == "" {
return nil
}
lowerWhere := strings.ToLower(where)
// List of dangerous SQL keywords that should never appear in WHERE clauses
dangerousKeywords := []string{
"delete ", "delete\t", "delete\n", "delete;",
"update ", "update\t", "update\n", "update;",
"truncate ", "truncate\t", "truncate\n", "truncate;",
"drop ", "drop\t", "drop\n", "drop;",
"alter ", "alter\t", "alter\n", "alter;",
"create ", "create\t", "create\n", "create;",
"insert ", "insert\t", "insert\n", "insert;",
"grant ", "grant\t", "grant\n", "grant;",
"revoke ", "revoke\t", "revoke\n", "revoke;",
"exec ", "exec\t", "exec\n", "exec;",
"execute ", "execute\t", "execute\n", "execute;",
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
}
for _, keyword := range dangerousKeywords {
if strings.Contains(lowerWhere, keyword) {
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
}
}
return nil
}
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL // This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
// //
// Parameters: // Parameters:
// - where: The WHERE clause string to sanitize // - where: The WHERE clause string to sanitize
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing) // - tableName: The correct table/relation name to use when fixing incorrect prefixes
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
// //
// Returns: // Returns:
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed // - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
// - An empty string if all conditions were trivial or the input was empty // - An empty string if all conditions were trivial or the input was empty
func SanitizeWhereClause(where string, tableName string) string { //
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
// prefix matches a preloaded relation name, in which case it's left unchanged.
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
if where == "" { if where == "" {
return "" return ""
} }
where = strings.TrimSpace(where) where = strings.TrimSpace(where)
// Validate that the WHERE clause doesn't contain dangerous SQL statements
if err := validateWhereClauseSecurity(where); err != nil {
logger.Debug("Security validation failed for WHERE clause: %v", err)
return ""
}
// Strip outer parentheses and re-trim // Strip outer parentheses and re-trim
where = stripOuterParentheses(where) where = stripOuterParentheses(where)
@@ -146,6 +151,22 @@ func SanitizeWhereClause(where string, tableName string) string {
validColumns = getValidColumnsForTable(tableName) validColumns = getValidColumnsForTable(tableName)
} }
// Build a set of allowed table prefixes (main table + preloaded relations)
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
}
// Split by AND to handle multiple conditions // Split by AND to handle multiple conditions
conditions := splitByAND(where) conditions := splitByAND(where)
@@ -166,22 +187,23 @@ func SanitizeWhereClause(where string, tableName string) string {
continue continue
} }
// If tableName is provided and the condition doesn't already have a table prefix, // If tableName is provided and the condition HAS a table prefix, check if it's correct
// attempt to add it if tableName != "" && hasTablePrefix(condToCheck) {
if tableName != "" && !hasTablePrefix(condToCheck) { // Extract the current prefix and column name
// Check if this is a SQL expression/literal that shouldn't be prefixed currentPrefix, columnName := extractTableAndColumn(condToCheck)
if !IsSQLExpression(strings.ToLower(condToCheck)) {
// Extract the column name and prefix it if currentPrefix != "" && columnName != "" {
columnName := ExtractColumnName(condToCheck) // Check if the prefix is allowed (main table or preload relation)
if columnName != "" { if !allowedPrefixes[currentPrefix] {
// Only prefix if this is a valid column in the model // Prefix is not in the allowed list - only fix if it's a valid column in the main table
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
if validColumns == nil || isValidColumn(columnName, validColumns) { if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace in the original condition (without stripped parens) // Replace the incorrect prefix with the correct main table name
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1) oldRef := currentPrefix + "." + columnName
logger.Debug("Prefixed column in condition: '%s'", cond) newRef := tableName + "." + columnName
cond = strings.Replace(cond, oldRef, newRef, 1)
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
} else { } else {
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName) logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
} }
} }
} }
@@ -241,19 +263,57 @@ func stripOuterParentheses(s string) string {
} }
// splitByAND splits a WHERE clause by AND operators (case-insensitive) // splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions // This is parenthesis-aware and won't split on AND operators inside subqueries
func splitByAND(where string) []string { func splitByAND(where string) []string {
// First try uppercase AND conditions := []string{}
conditions := strings.Split(where, " AND ") currentCondition := strings.Builder{}
depth := 0 // Track parenthesis depth
i := 0
// If we didn't split on uppercase, try lowercase for i < len(where) {
if len(conditions) == 1 { ch := where[i]
conditions = strings.Split(where, " and ")
// Track parenthesis depth
if ch == '(' {
depth++
currentCondition.WriteByte(ch)
i++
continue
} else if ch == ')' {
depth--
currentCondition.WriteByte(ch)
i++
continue
}
// Only look for AND operators at depth 0 (not inside parentheses)
if depth == 0 {
// Check if we're at an AND operator (case-insensitive)
// We need at least " AND " (5 chars) or " and " (5 chars)
if i+5 <= len(where) {
substring := where[i : i+5]
lowerSubstring := strings.ToLower(substring)
if lowerSubstring == " and " {
// Found an AND operator at the top level
// Add the current condition to the list
conditions = append(conditions, currentCondition.String())
currentCondition.Reset()
// Skip past the AND operator
i += 5
continue
}
}
}
// Not an AND operator or we're inside parentheses, just add the character
currentCondition.WriteByte(ch)
i++
} }
// If we still didn't split, try mixed case // Add the last condition
if len(conditions) == 1 { if currentCondition.Len() > 0 {
conditions = strings.Split(where, " And ") conditions = append(conditions, currentCondition.String())
} }
return conditions return conditions
@@ -330,6 +390,108 @@ func getValidColumnsForTable(tableName string) map[string]bool {
return columnMap return columnMap
} }
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
// For example: "users.status = 'active'" returns ("users", "status")
// Returns empty strings if no table prefix is found
// This function is parenthesis-aware and will only look for operators outside of subqueries
func extractTableAndColumn(cond string) (table string, column string) {
// Common SQL operators to find the column reference
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
var columnRef string
// Find the column reference (left side of the operator)
// We need to find the first operator that appears OUTSIDE of parentheses
minIdx := -1
for _, op := range operators {
idx := findOperatorOutsideParentheses(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
}
// If no operator found, the whole condition might be the column reference
if columnRef == "" {
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return "", ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Check if it contains a dot (qualified reference)
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
table = columnRef[:dotIdx]
column = columnRef[dotIdx+1:]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
return "", ""
}
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
// Returns the index of the operator, or -1 if not found or only found inside parentheses
func findOperatorOutsideParentheses(s string, operator string) int {
depth := 0
inSingleQuote := false
inDoubleQuote := false
for i := 0; i < len(s); i++ {
ch := s[i]
// Track quote state (operators inside quotes should be ignored)
if ch == '\'' && !inDoubleQuote {
inSingleQuote = !inSingleQuote
continue
}
if ch == '"' && !inSingleQuote {
inDoubleQuote = !inDoubleQuote
continue
}
// Skip if we're inside quotes
if inSingleQuote || inDoubleQuote {
continue
}
// Track parenthesis depth
switch ch {
case '(':
depth++
case ')':
depth--
}
// Only look for the operator when we're outside parentheses (depth == 0)
if depth == 0 {
// Check if the operator starts at this position
if i+len(operator) <= len(s) {
if s[i:i+len(operator)] == operator {
return i
}
}
}
}
return -1
}
// isValidColumn checks if a column name exists in the valid columns map // isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison // Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool { func isValidColumn(columnName string, validColumns map[string]bool) bool {

View File

@@ -1,6 +1,7 @@
package common package common
import ( import (
"strings"
"testing" "testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@@ -32,29 +33,41 @@ func TestSanitizeWhereClause(t *testing.T) {
expected: "", expected: "",
}, },
{ {
name: "valid condition with parentheses", name: "valid condition with parentheses - no prefix added",
where: "(status = 'active')", where: "(status = 'active')",
tableName: "users", tableName: "users",
expected: "users.status = 'active'", expected: "status = 'active'",
}, },
{ {
name: "mixed trivial and valid conditions", name: "mixed trivial and valid conditions - no prefix added",
where: "true AND status = 'active' AND 1=1", where: "true AND status = 'active' AND 1=1",
tableName: "users", tableName: "users",
expected: "users.status = 'active'", expected: "status = 'active'",
}, },
{ {
name: "condition already with table prefix", name: "condition with correct table prefix - unchanged",
where: "users.status = 'active'", where: "users.status = 'active'",
tableName: "users", tableName: "users",
expected: "users.status = 'active'", expected: "users.status = 'active'",
}, },
{ {
name: "multiple valid conditions", name: "condition with incorrect table prefix - fixed",
where: "status = 'active' AND age > 18", where: "wrong_table.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple conditions with incorrect prefix - fixed",
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
tableName: "users", tableName: "users",
expected: "users.status = 'active' AND users.age > 18", expected: "users.status = 'active' AND users.age > 18",
}, },
{
name: "multiple valid conditions without prefix - no prefix added",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "status = 'active' AND age > 18",
},
{ {
name: "no table name provided", name: "no table name provided",
where: "status = 'active'", where: "status = 'active'",
@@ -67,6 +80,60 @@ func TestSanitizeWhereClause(t *testing.T) {
tableName: "users", tableName: "users",
expected: "", expected: "",
}, },
{
name: "mixed correct and incorrect prefixes",
where: "users.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "mixed case AND operators",
where: "status = 'active' AND age > 18 and name = 'John'",
tableName: "users",
expected: "status = 'active' AND age > 18 AND name = 'John'",
},
{
name: "subquery with ORDER BY and LIMIT - allowed",
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
tableName: "users",
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
},
{
name: "dangerous DELETE keyword - blocked",
where: "status = 'active'; DELETE FROM users",
tableName: "users",
expected: "",
},
{
name: "dangerous UPDATE keyword - blocked",
where: "1=1; UPDATE users SET admin = true",
tableName: "users",
expected: "",
},
{
name: "dangerous TRUNCATE keyword - blocked",
where: "status = 'active' OR TRUNCATE TABLE users",
tableName: "users",
expected: "",
},
{
name: "dangerous DROP keyword - blocked",
where: "status = 'active'; DROP TABLE users",
tableName: "users",
expected: "",
},
{
name: "subquery with table alias should not be modified",
where: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
tableName: "apiprovider",
expected: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
},
{
name: "complex subquery with AND and multiple operators",
where: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
tableName: "apiprovider",
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -120,6 +187,11 @@ func TestStripOuterParentheses(t *testing.T) {
input: " ( true ) ", input: " ( true ) ",
expected: "true", expected: "true",
}, },
{
name: "complex sub query",
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -159,6 +231,158 @@ func TestIsTrivialCondition(t *testing.T) {
} }
} }
func TestExtractTableAndColumn(t *testing.T) {
tests := []struct {
name string
input string
expectedTable string
expectedCol string
}{
{
name: "qualified column with equals",
input: "users.status = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "qualified column with greater than",
input: "users.age > 18",
expectedTable: "users",
expectedCol: "age",
},
{
name: "qualified column with LIKE",
input: "users.name LIKE '%john%'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "qualified column with IN",
input: "users.status IN ('active', 'pending')",
expectedTable: "users",
expectedCol: "status",
},
{
name: "unqualified column",
input: "status = 'active'",
expectedTable: "",
expectedCol: "",
},
{
name: "qualified with backticks",
input: "`users`.`status` = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "schema.table.column reference",
input: "public.users.status = 'active'",
expectedTable: "public.users",
expectedCol: "status",
},
{
name: "empty string",
input: "",
expectedTable: "",
expectedCol: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table, col := extractTableAndColumn(tt.input)
if table != tt.expectedTable || col != tt.expectedCol {
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
tt.input, table, col, tt.expectedTable, tt.expectedCol)
}
})
}
}
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "preload relation prefix is preserved",
where: "Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "Department.name = 'Engineering'",
},
{
name: "multiple preload relations - all preserved",
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
{Relation: "Manager"},
},
},
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
},
{
name: "mix of main table and preload relation",
where: "users.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "incorrect prefix fixed when not a preload relation",
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "no options provided - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: nil,
expected: "users.status = 'active'",
},
{
name: "empty preload list - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: &RequestOptions{Preload: []PreloadOption{}},
expected: "users.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result string
if tt.options != nil {
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
} else {
result = SanitizeWhereClause(tt.where, tt.tableName)
}
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests // Test model for model-aware sanitization tests
type MasterTask struct { type MasterTask struct {
ID int `bun:"id,pk"` ID int `bun:"id,pk"`
@@ -167,6 +391,131 @@ type MasterTask struct {
UserID int `bun:"user_id"` UserID int `bun:"user_id"`
} }
func TestSplitByAND(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "uppercase AND",
input: "status = 'active' AND age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "lowercase and",
input: "status = 'active' and age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "mixed case AND",
input: "status = 'active' AND age > 18 and name = 'John'",
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
},
{
name: "single condition",
input: "status = 'active'",
expected: []string{"status = 'active'"},
},
{
name: "multiple uppercase AND",
input: "a = 1 AND b = 2 AND c = 3",
expected: []string{"a = 1", "b = 2", "c = 3"},
},
{
name: "multiple case subquery",
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitByAND(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
return
}
for i := range result {
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
}
}
})
}
}
func TestValidateWhereClauseSecurity(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
}{
{
name: "safe WHERE clause",
input: "status = 'active' AND age > 18",
expectError: false,
},
{
name: "safe subquery",
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
expectError: false,
},
{
name: "DELETE keyword",
input: "status = 'active'; DELETE FROM users",
expectError: true,
},
{
name: "UPDATE keyword",
input: "1=1; UPDATE users SET admin = true",
expectError: true,
},
{
name: "TRUNCATE keyword",
input: "status = 'active' OR TRUNCATE TABLE users",
expectError: true,
},
{
name: "DROP keyword",
input: "status = 'active'; DROP TABLE users",
expectError: true,
},
{
name: "INSERT keyword",
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
expectError: true,
},
{
name: "ALTER keyword",
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
expectError: true,
},
{
name: "CREATE keyword",
input: "1=1; CREATE TABLE malicious (id INT)",
expectError: true,
},
{
name: "empty clause",
input: "",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateWhereClauseSecurity(tt.input)
if tt.expectError && err == nil {
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
}
if !tt.expectError && err != nil {
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
}
})
}
}
func TestSanitizeWhereClauseWithModel(t *testing.T) { func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model // Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask") err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
@@ -182,34 +531,52 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
expected string expected string
}{ }{
{ {
name: "valid column gets prefixed", name: "valid column without prefix - no prefix added",
where: "status = 'active'", where: "status = 'active'",
tableName: "mastertask", tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "multiple valid columns without prefix - no prefix added",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "status = 'active' AND user_id = 123",
},
{
name: "incorrect table prefix on valid column - fixed",
where: "wrong_table.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'", expected: "mastertask.status = 'active'",
}, },
{ {
name: "multiple valid columns get prefixed", name: "incorrect prefix on invalid column - not fixed",
where: "status = 'active' AND user_id = 123", where: "wrong_table.invalid_column = 'value'",
tableName: "mastertask", tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123", expected: "wrong_table.invalid_column = 'value'",
},
{
name: "invalid column does not get prefixed",
where: "invalid_column = 'value'",
tableName: "mastertask",
expected: "invalid_column = 'value'",
}, },
{ {
name: "mix of valid and trivial conditions", name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1", where: "true AND status = 'active' AND 1=1",
tableName: "mastertask", tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "parentheses with valid column - no prefix added",
where: "(status = 'active')",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "correct prefix - unchanged",
where: "mastertask.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'", expected: "mastertask.status = 'active'",
}, },
{ {
name: "parentheses with valid column", name: "multiple conditions with mixed prefixes",
where: "(status = 'active')", where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
tableName: "mastertask", tableName: "mastertask",
expected: "mastertask.status = 'active'", expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
}, },
} }

View File

@@ -71,15 +71,14 @@ func (n *SqlNull[T]) Scan(value any) error {
// Fallback: parse from string/bytes. // Fallback: parse from string/bytes.
switch v := value.(type) { switch v := value.(type) {
case string: case string:
return n.fromString(v) return n.FromString(v)
case []byte: case []byte:
return n.fromString(string(v)) return n.FromString(string(v))
default: default:
return n.fromString(fmt.Sprintf("%v", value)) return n.FromString(fmt.Sprintf("%v", value))
} }
} }
func (n *SqlNull[T]) FromString(s string) error {
func (n *SqlNull[T]) fromString(s string) error {
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
n.Valid = false n.Valid = false
n.Val = *new(T) n.Val = *new(T)
@@ -90,19 +89,14 @@ func (n *SqlNull[T]) fromString(s string) error {
var zero T var zero T
switch any(zero).(type) { switch any(zero).(type) {
case int, int8, int16, int32, int64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if i, err := strconv.ParseInt(s, 10, 64); err == nil { if i, err := strconv.ParseInt(s, 10, 64); err == nil {
n.Val = any(int64(i)).(T) // Cast to T (e.g., int16) reflect.ValueOf(&n.Val).Elem().SetInt(i)
n.Valid = true
}
case uint, uint8, uint16, uint32, uint64:
if u, err := strconv.ParseUint(s, 10, 64); err == nil {
n.Val = any(u).(T)
n.Valid = true n.Valid = true
} }
case float32, float64: case float32, float64:
if f, err := strconv.ParseFloat(s, 64); err == nil { if f, err := strconv.ParseFloat(s, 64); err == nil {
n.Val = any(f).(T) reflect.ValueOf(&n.Val).Elem().SetFloat(f)
n.Valid = true n.Valid = true
} }
case bool: case bool:
@@ -124,7 +118,6 @@ func (n *SqlNull[T]) fromString(s string) error {
n.Val = any(s).(T) n.Val = any(s).(T)
n.Valid = true n.Valid = true
} }
return nil return nil
} }
@@ -163,7 +156,7 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
// Fallback: unmarshal as string and parse. // Fallback: unmarshal as string and parse.
var s string var s string
if err := json.Unmarshal(b, &s); err == nil { if err := json.Unmarshal(b, &s); err == nil {
return n.fromString(s) return n.FromString(s)
} }
return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val) return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val)
@@ -517,6 +510,33 @@ func TryIfInt64(v any, def int64) int64 {
} }
// Constructor helpers - clean and fast value creation // Constructor helpers - clean and fast value creation
func Null[T any](v T, valid bool) SqlNull[T] {
return SqlNull[T]{Val: v, Valid: valid}
}
func NewSql[T any](value any) SqlNull[T] {
n := SqlNull[T]{}
if value == nil {
return n
}
// Fast path: exact match
if v, ok := value.(T); ok {
n.Val = v
n.Valid = true
return n
}
// Try from another SqlNull
if sn, ok := value.(SqlNull[T]); ok {
return sn
}
// Convert via string
_ = n.FromString(fmt.Sprintf("%v", value))
return n
}
func NewSqlInt16(v int16) SqlInt16 { func NewSqlInt16(v int16) SqlInt16 {
return SqlInt16{Val: v, Valid: true} return SqlInt16{Val: v, Valid: true}

View File

@@ -16,11 +16,11 @@ func TestNewSqlInt16(t *testing.T) {
input interface{} input interface{}
expected SqlInt16 expected SqlInt16
}{ }{
{"int", 42, NewSqlInt16(42)}, {"int", 42, Null(int16(42), true)},
{"int32", int32(100), NewSqlInt16(100)}, {"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), NewSqlInt16(200)}, {"int64", int64(200), NewSqlInt16(200)},
{"string", "123", NewSqlInt16(123)}, {"string", "123", NewSqlInt16(123)},
{"nil", nil, NewSqlInt16(0)}, {"nil", nil, Null(int16(0), false)},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -42,9 +42,9 @@ func TestNewSqlInt16_Value(t *testing.T) {
input SqlInt16 input SqlInt16
expected driver.Value expected driver.Value
}{ }{
{"zero", NewSqlInt16(0), nil}, {"zero", Null(int16(0), false), nil},
{"positive", NewSqlInt16(42), int64(42)}, {"positive", NewSqlInt16(42), int16(42)},
{"negative", NewSqlInt16(-10), int64(-10)}, {"negative", NewSqlInt16(-10), int16(-10)},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -486,7 +486,7 @@ func TestSqlUUID_Value(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Value failed: %v", err) t.Fatalf("Value failed: %v", err)
} }
if val != testUUID.String() { if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val) t.Errorf("expected %s, got %s", testUUID.String(), val)
} }

View File

@@ -4,13 +4,14 @@ import "time"
// Config represents the complete application configuration // Config represents the complete application configuration
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
Tracing TracingConfig `mapstructure:"tracing"` Tracing TracingConfig `mapstructure:"tracing"`
Cache CacheConfig `mapstructure:"cache"` Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"` Logger LoggerConfig `mapstructure:"logger"`
Middleware MiddlewareConfig `mapstructure:"middleware"` ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
CORS CORSConfig `mapstructure:"cors"` Middleware MiddlewareConfig `mapstructure:"middleware"`
Database DatabaseConfig `mapstructure:"database"` CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
} }
// ServerConfig holds server-related configuration // ServerConfig holds server-related configuration
@@ -78,3 +79,15 @@ type CORSConfig struct {
type DatabaseConfig struct { type DatabaseConfig struct {
URL string `mapstructure:"url"` URL string `mapstructure:"url"`
} }
// ErrorTrackingConfig holds error tracking configuration
type ErrorTrackingConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // sentry, noop
DSN string `mapstructure:"dsn"` // Sentry DSN
Environment string `mapstructure:"environment"` // e.g., production, staging, development
Release string `mapstructure:"release"` // Application version/release
Debug bool `mapstructure:"debug"` // Enable debug mode
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
}

150
pkg/errortracking/README.md Normal file
View File

@@ -0,0 +1,150 @@
# Error Tracking
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
## Features
- **Provider Interface**: Flexible design supporting multiple error tracking backends
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
- **Panic Tracking**: Automatic panic capture with stack traces
- **NoOp Provider**: Zero-overhead when error tracking is disabled
## Configuration
Add error tracking configuration to your config file:
```yaml
error_tracking:
enabled: true
provider: "sentry" # Currently supports: "sentry" or "noop"
dsn: "https://your-sentry-dsn@sentry.io/project-id"
environment: "production" # e.g., production, staging, development
release: "v1.0.0" # Your application version
debug: false
sample_rate: 1.0 # Error sample rate (0.0-1.0)
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
```
## Usage
### Initialization
Initialize error tracking in your application startup:
```go
package main
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func main() {
// Load your configuration
cfg := config.Config{
ErrorTracking: config.ErrorTrackingConfig{
Enabled: true,
Provider: "sentry",
DSN: "https://your-sentry-dsn@sentry.io/project-id",
Environment: "production",
Release: "v1.0.0",
SampleRate: 1.0,
},
}
// Initialize logger
logger.Init(false)
// Initialize error tracking
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
if err != nil {
logger.Error("Failed to initialize error tracking: %v", err)
} else {
logger.InitErrorTracking(provider)
}
// Your application code...
// Cleanup on shutdown
defer logger.CloseErrorTracking()
}
```
### Automatic Tracking
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
```go
// This will be logged AND sent to Sentry
logger.Error("Database connection failed: %v", err)
// This will also be logged AND sent to Sentry
logger.Warn("Cache miss for key: %s", key)
```
### Panic Tracking
Panics are automatically captured when using the logger's panic handlers:
```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
// Using HandlePanic
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("MyMethod", r)
}
}()
```
### Manual Tracking
You can also use the provider directly for custom error tracking:
```go
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func someFunction() {
tracker := logger.GetErrorTracker()
if tracker != nil {
// Capture an error
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
"user_id": userID,
"request_id": requestID,
})
// Capture a message
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
"event_type": "user_signup",
})
// Capture a panic
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
"context": "background_job",
})
}
}
```
## Severity Levels
The package supports the following severity levels:
- `SeverityError`: For errors that should be tracked and investigated
- `SeverityWarning`: For warnings that may indicate potential issues
- `SeverityInfo`: For informational messages
- `SeverityDebug`: For debug-level information
```

View File

@@ -0,0 +1,67 @@
package errortracking
import (
"context"
"errors"
"testing"
)
func TestNoOpProvider(t *testing.T) {
provider := NewNoOpProvider()
// Test that all methods can be called without panicking
t.Run("CaptureError", func(t *testing.T) {
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
})
t.Run("CaptureMessage", func(t *testing.T) {
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
})
t.Run("CapturePanic", func(t *testing.T) {
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
})
t.Run("Flush", func(t *testing.T) {
result := provider.Flush(5)
if !result {
t.Error("Expected Flush to return true")
}
})
t.Run("Close", func(t *testing.T) {
err := provider.Close()
if err != nil {
t.Errorf("Expected Close to return nil, got %v", err)
}
})
}
func TestSeverityLevels(t *testing.T) {
tests := []struct {
name string
severity Severity
expected string
}{
{"Error", SeverityError, "error"},
{"Warning", SeverityWarning, "warning"},
{"Info", SeverityInfo, "info"},
{"Debug", SeverityDebug, "debug"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.severity) != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
}
})
}
}
func TestProviderInterface(t *testing.T) {
// Test that NoOpProvider implements Provider interface
var _ Provider = (*NoOpProvider)(nil)
// Test that SentryProvider implements Provider interface
var _ Provider = (*SentryProvider)(nil)
}

View File

@@ -0,0 +1,33 @@
package errortracking
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates an error tracking provider based on the configuration
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
if !cfg.Enabled {
return NewNoOpProvider(), nil
}
switch cfg.Provider {
case "sentry":
if cfg.DSN == "" {
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
}
return NewSentryProvider(SentryConfig{
DSN: cfg.DSN,
Environment: cfg.Environment,
Release: cfg.Release,
Debug: cfg.Debug,
SampleRate: cfg.SampleRate,
TracesSampleRate: cfg.TracesSampleRate,
})
case "noop", "":
return NewNoOpProvider(), nil
default:
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
}
}

View File

@@ -0,0 +1,33 @@
package errortracking
import (
"context"
)
// Severity represents the severity level of an error
type Severity string
const (
SeverityError Severity = "error"
SeverityWarning Severity = "warning"
SeverityInfo Severity = "info"
SeverityDebug Severity = "debug"
)
// Provider defines the interface for error tracking providers
type Provider interface {
// CaptureError captures an error with the given severity and additional context
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
// CaptureMessage captures a message with the given severity and additional context
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
// CapturePanic captures a panic with stack trace
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
// Flush waits for all events to be sent (useful for graceful shutdown)
Flush(timeout int) bool
// Close closes the provider and releases resources
Close() error
}

37
pkg/errortracking/noop.go Normal file
View File

@@ -0,0 +1,37 @@
package errortracking
import "context"
// NoOpProvider is a no-op implementation of the Provider interface
// Used when error tracking is disabled
type NoOpProvider struct{}
// NewNoOpProvider creates a new NoOp provider
func NewNoOpProvider() *NoOpProvider {
return &NoOpProvider{}
}
// CaptureError does nothing
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
// No-op
}
// CaptureMessage does nothing
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
// No-op
}
// CapturePanic does nothing
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
// No-op
}
// Flush does nothing and returns true
func (n *NoOpProvider) Flush(timeout int) bool {
return true
}
// Close does nothing
func (n *NoOpProvider) Close() error {
return nil
}

154
pkg/errortracking/sentry.go Normal file
View File

@@ -0,0 +1,154 @@
package errortracking
import (
"context"
"fmt"
"time"
"github.com/getsentry/sentry-go"
)
// SentryProvider implements the Provider interface using Sentry
type SentryProvider struct {
hub *sentry.Hub
}
// SentryConfig holds the configuration for Sentry
type SentryConfig struct {
DSN string
Environment string
Release string
Debug bool
SampleRate float64
TracesSampleRate float64
}
// NewSentryProvider creates a new Sentry provider
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
err := sentry.Init(sentry.ClientOptions{
Dsn: config.DSN,
Environment: config.Environment,
Release: config.Release,
Debug: config.Debug,
AttachStacktrace: true,
SampleRate: config.SampleRate,
TracesSampleRate: config.TracesSampleRate,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
}
return &SentryProvider{
hub: sentry.CurrentHub(),
}, nil
}
// CaptureError captures an error with the given severity and additional context
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
if err == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = err.Error()
event.Exception = []sentry.Exception{
{
Value: err.Error(),
Type: fmt.Sprintf("%T", err),
Stacktrace: sentry.ExtractStacktrace(err),
},
}
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CaptureMessage captures a message with the given severity and additional context
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
if message == "" {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = message
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CapturePanic captures a panic with stack trace
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
if recovered == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = sentry.LevelError
event.Message = fmt.Sprintf("Panic: %v", recovered)
event.Exception = []sentry.Exception{
{
Value: fmt.Sprintf("%v", recovered),
Type: "panic",
},
}
if extra != nil {
event.Extra = extra
}
if stackTrace != nil {
event.Extra["stack_trace"] = string(stackTrace)
}
hub.CaptureEvent(event)
}
// Flush waits for all events to be sent (useful for graceful shutdown)
func (s *SentryProvider) Flush(timeout int) bool {
return sentry.Flush(time.Duration(timeout) * time.Second)
}
// Close closes the provider and releases resources
func (s *SentryProvider) Close() error {
sentry.Flush(2 * time.Second)
return nil
}
// convertSeverity converts our Severity to Sentry's Level
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
switch severity {
case SeverityError:
return sentry.LevelError
case SeverityWarning:
return sentry.LevelWarning
case SeverityInfo:
return sentry.LevelInfo
case SeverityDebug:
return sentry.LevelDebug
default:
return sentry.LevelError
}
}

View File

@@ -16,8 +16,8 @@ import (
// MockDatabase implements common.Database interface for testing // MockDatabase implements common.Database interface for testing
type MockDatabase struct { type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error) ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
} }
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{}) handler := NewHandler(&MockDatabase{})
tests := []struct { tests := []struct {
name string name string
sqlQuery string sqlQuery string
expectedVars []string expectedVars []string
}{ }{
{ {
name: "No variables", name: "No variables",
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
// TestGetIPAddress tests IP address extraction // TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) { func TestGetIPAddress(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupReq func() *http.Request setupReq func() *http.Request
expected string expected string
}{ }{
{ {
name: "X-Forwarded-For header", name: "X-Forwarded-For header",
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{}) handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{ userCtx := &security.UserContext{
UserID: 123, UserID: 123,
UserName: "testuser", UserName: "testuser",
SessionID: "456", SessionID: "ABC456",
SessionRID: 456,
} }
metainfo := map[string]interface{}{ metainfo := map[string]interface{}{
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
expectedCheck: func(result string) bool { expectedCheck: func(result string) bool {
return strings.Contains(result, "456") return strings.Contains(result, "456")
}, },
}, {
name: "Replace [id_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "ABC456")
},
}, },
} }

View File

@@ -1,15 +1,19 @@
package logger package logger
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"os" "os"
"runtime/debug" "runtime/debug"
"go.uber.org/zap" "go.uber.org/zap"
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
) )
var Logger *zap.SugaredLogger var Logger *zap.SugaredLogger
var errorTracker errortracking.Provider
func Init(dev bool) { func Init(dev bool) {
@@ -49,6 +53,28 @@ func UpdateLogger(config *zap.Config) {
Info("ResolveSpec Logger initialized") Info("ResolveSpec Logger initialized")
} }
// InitErrorTracking initializes the error tracking provider
func InitErrorTracking(provider errortracking.Provider) {
errorTracker = provider
if errorTracker != nil {
Info("Error tracking initialized")
}
}
// GetErrorTracker returns the current error tracking provider
func GetErrorTracker() errortracking.Provider {
return errorTracker
}
// CloseErrorTracking flushes and closes the error tracking provider
func CloseErrorTracking() error {
if errorTracker != nil {
errorTracker.Flush(5)
return errorTracker.Close()
}
return nil
}
func Info(template string, args ...interface{}) { func Info(template string, args ...interface{}) {
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf(template, args...)
@@ -58,19 +84,35 @@ func Info(template string, args ...interface{}) {
} }
func Warn(template string, args ...interface{}) { func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf("%s", message)
return } else {
Logger.Warnw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
} }
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
func Error(template string, args ...interface{}) { func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf("%s", message)
return } else {
Logger.Errorw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
} }
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
func Debug(template string, args ...interface{}) { func Debug(template string, args ...interface{}) {
@@ -84,7 +126,7 @@ func Debug(template string, args ...interface{}) {
// CatchPanic - Handle panic // CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) { func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil { if err := recover(); err != nil {
// callstack := debug.Stack() callstack := debug.Stack()
if Logger != nil { if Logger != nil {
Error("Panic in %s : %v", location, err) Error("Panic in %s : %v", location, err)
@@ -93,14 +135,13 @@ func CatchPanicCallback(location string, cb func(err any)) {
debug.PrintStack() debug.PrintStack()
} }
// push to sentry // Send to error tracker
// hub := sentry.CurrentHub() if errorTracker != nil {
// if hub != nil { errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
// evtID := hub.Recover(err) "location": location,
// if evtID != nil { "process_id": os.Getpid(),
// sentry.Flush(time.Second * 2) })
// } }
// }
if cb != nil { if cb != nil {
cb(err) cb(err)
@@ -125,5 +166,14 @@ func CatchPanic(location string) {
func HandlePanic(methodName string, r any) error { func HandlePanic(methodName string, r any) error {
stack := debug.Stack() stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})
}
return fmt.Errorf("panic in %s: %v", methodName, r) return fmt.Errorf("panic in %s: %v", methodName, r)
} }

View File

@@ -0,0 +1,331 @@
package reflection
import (
"reflect"
"testing"
)
// Test models for GetModelColumnDetail
type TestModelForColumnDetail struct {
ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"`
Name string `gorm:"column:name;type:varchar(255);not null" json:"name"`
Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"`
Description string `gorm:"column:description;type:text;null" json:"description"`
ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"`
}
type EmbeddedBase struct {
ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"`
CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"`
}
type ModelWithEmbeddedForDetail struct {
EmbeddedBase
Title string `gorm:"column:title;type:varchar(100);not null" json:"title"`
Content string `gorm:"column:content;type:text" json:"content"`
}
// Model with nil embedded pointer
type ModelWithNilEmbedded struct {
ID int `gorm:"column:id;primaryKey" json:"id"`
*EmbeddedBase
Name string `gorm:"column:name" json:"name"`
}
func TestGetModelColumnDetail(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
model := TestModelForColumnDetail{
ID: 1,
Name: "Test",
Email: "test@example.com",
Description: "Test description",
ForeignKey: 100,
}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) != 5 {
t.Errorf("Expected 5 fields, got %d", len(details))
}
// Check ID field
found := false
for _, detail := range details {
if detail.Name == "ID" {
found = true
if detail.SQLName != "rid_test" {
t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName)
}
// Note: primaryKey (without underscore) is not detected as primary_key
// The function looks for "identity" or "primary_key" (with underscore)
if detail.SQLDataType != "bigserial" {
t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType)
}
if detail.Nullable {
t.Errorf("Expected Nullable false, got true")
}
}
}
if !found {
t.Errorf("ID field not found in details")
}
})
t.Run("struct with embedded fields", func(t *testing.T) {
model := ModelWithEmbeddedForDetail{
EmbeddedBase: EmbeddedBase{
ID: 1,
CreatedAt: "2024-01-01",
},
Title: "Test Title",
Content: "Test Content",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
// Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main
if len(details) != 4 {
t.Errorf("Expected 4 fields, got %d", len(details))
}
// Check that embedded field is included
foundID := false
foundCreatedAt := false
for _, detail := range details {
if detail.Name == "ID" {
foundID = true
if detail.SQLKey != "primary_key" {
t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey)
}
}
if detail.Name == "CreatedAt" {
foundCreatedAt = true
}
}
if !foundID {
t.Errorf("Embedded ID field not found")
}
if !foundCreatedAt {
t.Errorf("Embedded CreatedAt field not found")
}
})
t.Run("nil embedded pointer is skipped", func(t *testing.T) {
model := ModelWithNilEmbedded{
ID: 1,
Name: "Test",
EmbeddedBase: nil, // nil embedded pointer
}
details := GetModelColumnDetail(reflect.ValueOf(model))
// Should have 2 fields: ID and Name (embedded is nil, so skipped)
if len(details) != 2 {
t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details))
}
})
t.Run("pointer to struct", func(t *testing.T) {
model := &TestModelForColumnDetail{
ID: 1,
Name: "Test",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) != 5 {
t.Errorf("Expected 5 fields, got %d", len(details))
}
})
t.Run("invalid value", func(t *testing.T) {
var invalid reflect.Value
details := GetModelColumnDetail(invalid)
if len(details) != 0 {
t.Errorf("Expected 0 fields for invalid value, got %d", len(details))
}
})
t.Run("non-struct type", func(t *testing.T) {
details := GetModelColumnDetail(reflect.ValueOf(123))
if len(details) != 0 {
t.Errorf("Expected 0 fields for non-struct, got %d", len(details))
}
})
t.Run("nullable and not null detection", func(t *testing.T) {
model := TestModelForColumnDetail{}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
switch detail.Name {
case "ID":
if detail.Nullable {
t.Errorf("ID should not be nullable (has 'not null')")
}
case "Name":
if detail.Nullable {
t.Errorf("Name should not be nullable (has 'not null')")
}
case "Email":
if !detail.Nullable {
t.Errorf("Email should be nullable (has 'nullable')")
}
case "Description":
if !detail.Nullable {
t.Errorf("Description should be nullable (has 'null')")
}
}
}
})
t.Run("unique and uniqueindex detection", func(t *testing.T) {
type UniqueTestModel struct {
ID int `gorm:"column:id;primary_key"`
Username string `gorm:"column:username;unique"`
Email string `gorm:"column:email;uniqueindex"`
}
model := UniqueTestModel{}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
switch detail.Name {
case "ID":
if detail.SQLKey != "primary_key" {
t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey)
}
case "Username":
if detail.SQLKey != "unique" {
t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey)
}
case "Email":
// The function checks for "unique" first, so uniqueindex is also detected as "unique"
// This is expected behavior based on the code logic
if detail.SQLKey != "unique" {
t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey)
}
}
}
})
t.Run("foreign key detection", func(t *testing.T) {
// Note: The foreignkey extraction in generic_model.go has a bug where
// it requires ik > 0, so foreignkey at the start won't extract the value
type FKTestModel struct {
ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"`
}
model := FKTestModel{}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) == 0 {
t.Fatal("Expected at least 1 field")
}
detail := details[0]
if detail.SQLKey != "foreign_key" {
t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey)
}
// Due to the bug in the code (requires ik > 0), the SQLName will be extracted
// when foreignkey is not at the beginning of the string
if detail.SQLName != "rid_parent" {
t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName)
}
})
}
func TestFnFindKeyVal(t *testing.T) {
tests := []struct {
name string
src string
key string
expected string
}{
{
name: "find column",
src: "column:user_id;primaryKey;type:bigint",
key: "column:",
expected: "user_id",
},
{
name: "find type",
src: "column:name;type:varchar(255);not null",
key: "type:",
expected: "varchar(255)",
},
{
name: "key not found",
src: "primaryKey;autoIncrement",
key: "column:",
expected: "",
},
{
name: "key at end without semicolon",
src: "primaryKey;column:id",
key: "column:",
expected: "id",
},
{
name: "case insensitive search",
src: "Column:user_id;primaryKey",
key: "column:",
expected: "user_id",
},
{
name: "empty src",
src: "",
key: "column:",
expected: "",
},
{
name: "multiple occurrences (returns first)",
src: "column:first;column:second",
key: "column:",
expected: "first",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := fnFindKeyVal(tt.src, tt.key)
if result != tt.expected {
t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected)
}
})
}
}
func TestGetModelColumnDetail_FieldValue(t *testing.T) {
model := TestModelForColumnDetail{
ID: 123,
Name: "TestName",
Email: "test@example.com",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
if !detail.FieldValue.IsValid() {
t.Errorf("Field %s has invalid FieldValue", detail.Name)
}
// Check that FieldValue matches the actual value
switch detail.Name {
case "ID":
if detail.FieldValue.Int() != 123 {
t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int())
}
case "Name":
if detail.FieldValue.String() != "TestName" {
t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String())
}
case "Email":
if detail.FieldValue.String() != "test@example.com" {
t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String())
}
}
}
}

View File

@@ -750,6 +750,118 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
return nil, fmt.Errorf("unsupported numeric type: %v", kind) return nil, fmt.Errorf("unsupported numeric type: %v", kind)
} }
// RelationType represents the type of database relationship
type RelationType string
const (
RelationHasMany RelationType = "has-many" // 1:N - use separate query
RelationBelongsTo RelationType = "belongs-to" // N:1 - use JOIN
RelationHasOne RelationType = "has-one" // 1:1 - use JOIN
RelationManyToMany RelationType = "many-to-many" // M:N - use separate query
RelationUnknown RelationType = "unknown"
)
// ShouldUseJoin returns true if the relation type should use a JOIN instead of separate query
func (rt RelationType) ShouldUseJoin() bool {
return rt == RelationBelongsTo || rt == RelationHasOne
}
// GetRelationType inspects the model's struct tags to determine the relationship type
// It checks both Bun and GORM tags to identify the relationship cardinality
func GetRelationType(model interface{}, fieldName string) RelationType {
if model == nil || fieldName == "" {
return RelationUnknown
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return RelationUnknown
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return RelationUnknown
}
// Find the field
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check if field name matches (case-insensitive)
if !strings.EqualFold(field.Name, fieldName) {
continue
}
// Check Bun tags first
bunTag := field.Tag.Get("bun")
if bunTag != "" && strings.Contains(bunTag, "rel:") {
// Parse bun relation tag: rel:has-many, rel:belongs-to, rel:has-one, rel:many-to-many
parts := strings.Split(bunTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "rel:") {
relType := strings.TrimPrefix(part, "rel:")
switch relType {
case "has-many":
return RelationHasMany
case "belongs-to":
return RelationBelongsTo
case "has-one":
return RelationHasOne
case "many-to-many", "m2m":
return RelationManyToMany
}
}
}
}
// Check GORM tags
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
// GORM uses different patterns:
// - foreignKey: usually indicates belongs-to or has-one
// - many2many: indicates many-to-many
// - Field type (slice vs pointer) helps determine cardinality
if strings.Contains(gormTag, "many2many:") {
return RelationManyToMany
}
// Check field type for cardinality hints
fieldType := field.Type
if fieldType.Kind() == reflect.Slice {
// Slice indicates has-many or many-to-many
return RelationHasMany
}
if fieldType.Kind() == reflect.Ptr {
// Pointer to single struct usually indicates belongs-to or has-one
// Check if it has foreignKey (belongs-to) or references (has-one)
if strings.Contains(gormTag, "foreignKey:") {
return RelationBelongsTo
}
return RelationHasOne
}
}
// Fall back to field type inference
fieldType := field.Type
if fieldType.Kind() == reflect.Slice {
// Slice of structs → has-many
return RelationHasMany
}
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
// Single struct → belongs-to (default assumption for safety)
// Using belongs-to as default ensures we use JOIN, which is safer
return RelationBelongsTo
}
}
return RelationUnknown
}
// GetRelationModel gets the model type for a relation field // GetRelationModel gets the model type for a relation field
// It searches for the field by name in the following order (case-insensitive): // It searches for the field by name in the following order (case-insensitive):
// 1. Actual field name // 1. Actual field name

File diff suppressed because it is too large Load Diff

View File

@@ -316,7 +316,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query // Apply cursor filter to query
if cursorFilter != "" { if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter) logger.Debug("Applying cursor filter: %s", cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName)) sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
if sanitizedCursor != "" { if sanitizedCursor != "" {
query = query.Where(sanitizedCursor) query = query.Where(sanitizedCursor)
} }
@@ -1351,7 +1351,9 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
} }
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) // Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: preloads}
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
if len(sanitizedWhere) > 0 { if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere) sq = sq.Where(sanitizedWhere)
} }

View File

@@ -450,7 +450,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
// Apply the preload with recursive support // Apply the preload with recursive support
query = h.applyPreloadWithRecursion(query, preload, model, 0) query = h.applyPreloadWithRecursion(query, preload, options.Preload, model, 0)
} }
// Apply DISTINCT if requested // Apply DISTINCT if requested
@@ -480,8 +480,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (AND condition) // Apply custom SQL WHERE clause (AND condition)
if options.CustomSQLWhere != "" { if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName)) sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedWhere != "" { if sanitizedWhere != "" {
query = query.Where(sanitizedWhere) query = query.Where(sanitizedWhere)
} }
@@ -490,8 +490,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (OR condition) // Apply custom SQL WHERE clause (OR condition)
if options.CustomSQLOr != "" { if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName)) sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedOr != "" { if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr) query = query.WhereOr(sanitizedOr)
} }
@@ -625,7 +625,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query // Apply cursor filter to query
if cursorFilter != "" { if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter) logger.Debug("Applying cursor filter: %s", cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName)) sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedCursor != "" { if sanitizedCursor != "" {
query = query.Where(sanitizedCursor) query = query.Where(sanitizedCursor)
} }
@@ -703,7 +703,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading // applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery { func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, allPreloads []common.PreloadOption, model interface{}, depth int) common.SelectQuery {
// Log relationship keys if they're specified (from XFiles) // Log relationship keys if they're specified (from XFiles)
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" { if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s", logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
@@ -799,7 +799,9 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply WHERE clause // Apply WHERE clause
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) // Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: allPreloads}
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
if len(sanitizedWhere) > 0 { if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere) sq = sq.Where(sanitizedWhere)
} }
@@ -832,7 +834,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
recursivePreload.Relation = preload.Relation + "." + lastRelationName recursivePreload.Relation = preload.Relation + "." + lastRelationName
// Recursively apply preload until we reach depth 5 // Recursively apply preload until we reach depth 5
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1) query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
} }
return query return query

View File

@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/cache" "github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// Production-Ready Authenticators // Production-Ready Authenticators
@@ -169,69 +170,98 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
// Extract session token from header or cookie // Extract session token from header or cookie
sessionToken := r.Header.Get("Authorization") sessionToken := r.Header.Get("Authorization")
reference := "authenticate" reference := "authenticate"
var tokens []string
if sessionToken == "" { if sessionToken == "" {
// Try cookie // Try cookie
cookie, err := r.Cookie("session_token") cookie, err := r.Cookie("session_token")
if err == nil { if err == nil {
sessionToken = cookie.Value tokens = []string{cookie.Value}
reference = "cookie" reference = "cookie"
} }
} else { } else {
// Remove "Bearer " prefix if present // Parse Authorization header which may contain multiple comma-separated tokens
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ") // Format: "Token abc, Token def" or "Bearer abc" or just "abc"
// Remove "Token " prefix if present rawTokens := strings.Split(sessionToken, ",")
sessionToken = strings.TrimPrefix(sessionToken, "Token ") for _, token := range rawTokens {
token = strings.TrimSpace(token)
// Remove "Bearer " prefix if present
token = strings.TrimPrefix(token, "Bearer ")
// Remove "Token " prefix if present
token = strings.TrimPrefix(token, "Token ")
token = strings.TrimSpace(token)
if token != "" {
tokens = append(tokens, token)
}
}
} }
if sessionToken == "" { if len(tokens) == 0 {
return nil, fmt.Errorf("session token required") return nil, fmt.Errorf("session token required")
} }
// Build cache key // Log warning if multiple tokens are provided
cacheKey := fmt.Sprintf("auth:session:%s", sessionToken) if len(tokens) > 1 {
logger.Warn("Multiple authentication tokens provided in Authorization header (%d tokens). This is unusual and may indicate a misconfigured client. Header: %s", len(tokens), sessionToken)
// Use cache.GetOrSet to get from cache or load from database
var userCtx UserContext
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (interface{}, error) {
// This function is called only if cache miss
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
if !userJSON.Valid {
return nil, fmt.Errorf("no user data in session")
}
// Parse UserContext
var user UserContext
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &user, nil
})
if err != nil {
return nil, err
} }
// Update last activity timestamp asynchronously // Try each token until one succeeds
go a.updateSessionActivity(r.Context(), sessionToken, &userCtx) var lastErr error
for _, token := range tokens {
// Build cache key
cacheKey := fmt.Sprintf("auth:session:%s", token)
return &userCtx, nil // Use cache.GetOrSet to get from cache or load from database
var userCtx UserContext
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (any, error) {
// This function is called only if cache miss
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
if !userJSON.Valid {
return nil, fmt.Errorf("no user data in session")
}
// Parse UserContext
var user UserContext
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &user, nil
})
if err != nil {
lastErr = err
continue // Try next token
}
// Authentication succeeded with this token
// Update last activity timestamp asynchronously
go a.updateSessionActivity(r.Context(), token, &userCtx)
return &userCtx, nil
}
// All tokens failed
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("authentication failed for all provided tokens")
} }
// ClearCache removes a specific token from the cache or clears all cache if token is empty // ClearCache removes a specific token from the cache or clears all cache if token is empty

View File

@@ -545,6 +545,96 @@ func TestDatabaseAuthenticator(t *testing.T) {
t.Fatal("expected error when token is missing") t.Fatal("expected error when token is missing")
} }
}) })
t.Run("authenticate with multiple comma-separated tokens", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token invalid-token, Token valid-token-123")
// First token fails
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("invalid-token", "authenticate").
WillReturnRows(rows1)
// Second token succeeds
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":3,"user_name":"multitoken","session_id":"valid-token-123"}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("valid-token-123", "authenticate").
WillReturnRows(rows2)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if userCtx.UserID != 3 {
t.Errorf("expected UserID 3, got %d", userCtx.UserID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("authenticate with duplicate tokens", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A, Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A")
// First token succeeds
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":4,"user_name":"duplicateuser","session_id":"968CA5AE-4F83-4D55-A3C6-51AE4410E03A"}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("968CA5AE-4F83-4D55-A3C6-51AE4410E03A", "authenticate").
WillReturnRows(rows)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if userCtx.UserID != 4 {
t.Errorf("expected UserID 4, got %d", userCtx.UserID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("authenticate with all tokens failing", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token bad-token-1, Token bad-token-2")
// First token fails
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("bad-token-1", "authenticate").
WillReturnRows(rows1)
// Second token also fails
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("bad-token-2", "authenticate").
WillReturnRows(rows2)
_, err := auth.Authenticate(req)
if err == nil {
t.Fatal("expected error when all tokens fail")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
} }
// Test DatabaseAuthenticator RefreshToken // Test DatabaseAuthenticator RefreshToken