mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 08:14:25 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0962ea1ec | ||
|
|
8fcb065b42 | ||
|
|
dc3b621380 | ||
|
|
a4dd2a7086 | ||
|
|
3ec2e5f15a | ||
|
|
c52afe2825 | ||
|
|
76e98d02c3 |
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
@@ -15,6 +16,24 @@ import (
|
|||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// QueryDebugHook is a Bun query hook that logs all SQL queries including preloads
|
||||||
|
type QueryDebugHook struct{}
|
||||||
|
|
||||||
|
func (h *QueryDebugHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {
|
||||||
|
query := event.Query
|
||||||
|
duration := time.Since(event.StartTime)
|
||||||
|
|
||||||
|
if event.Err != nil {
|
||||||
|
logger.Error("SQL Query Failed [%s]: %s. Error: %v", duration, query, event.Err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("SQL Query Success [%s]: %s", duration, query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BunAdapter adapts Bun to work with our Database interface
|
// BunAdapter adapts Bun to work with our Database interface
|
||||||
// This demonstrates how the abstraction works with different ORMs
|
// This demonstrates how the abstraction works with different ORMs
|
||||||
type BunAdapter struct {
|
type BunAdapter struct {
|
||||||
@@ -26,6 +45,20 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
|||||||
return &BunAdapter{db: db}
|
return &BunAdapter{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||||
|
// This is useful for debugging preload queries that may be failing
|
||||||
|
func (b *BunAdapter) EnableQueryDebug() {
|
||||||
|
b.db.AddQueryHook(&QueryDebugHook{})
|
||||||
|
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableQueryDebug removes all query hooks
|
||||||
|
func (b *BunAdapter) DisableQueryDebug() {
|
||||||
|
// Create a new DB without hooks
|
||||||
|
// Note: Bun doesn't have a RemoveQueryHook, so we'd need to track hooks manually
|
||||||
|
logger.Info("To disable query debug, recreate the BunAdapter without adding the hook")
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||||
return &BunSelectQuery{
|
return &BunSelectQuery{
|
||||||
query: b.db.NewSelect(),
|
query: b.db.NewSelect(),
|
||||||
@@ -410,6 +443,9 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
|||||||
// Execute the main query first
|
// Execute the main query first
|
||||||
err = b.query.Scan(ctx, dest)
|
err = b.query.Scan(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := b.query.String()
|
||||||
|
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,6 +474,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
// Execute the main query first
|
// Execute the main query first
|
||||||
err = b.query.Scan(ctx)
|
err = b.query.Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := b.query.String()
|
||||||
|
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -573,15 +612,25 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
// If Model() was set, use bun's native Count() which works properly
|
// If Model() was set, use bun's native Count() which works properly
|
||||||
if b.hasModel {
|
if b.hasModel {
|
||||||
count, err := b.query.Count(ctx)
|
count, err := b.query.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := b.query.String()
|
||||||
|
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||||
// This is needed when only Table() is set without a model
|
// This is needed when only Table() is set without a model
|
||||||
err = b.db.NewSelect().
|
countQuery := b.db.NewSelect().
|
||||||
TableExpr("(?) AS subquery", b.query).
|
TableExpr("(?) AS subquery", b.query).
|
||||||
ColumnExpr("COUNT(*)").
|
ColumnExpr("COUNT(*)")
|
||||||
Scan(ctx, &count)
|
err = countQuery.Scan(ctx, &count)
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := countQuery.String()
|
||||||
|
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -592,7 +641,13 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
|||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.query.Exists(ctx)
|
exists, err = b.query.Exists(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := b.query.String()
|
||||||
|
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
|
return exists, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// BunInsertQuery implements InsertQuery for Bun
|
// BunInsertQuery implements InsertQuery for Bun
|
||||||
@@ -729,6 +784,11 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := b.query.String()
|
||||||
|
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -759,6 +819,11 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := b.query.String()
|
||||||
|
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
|||||||
return &GormAdapter{db: db}
|
return &GormAdapter{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||||
|
// This is useful for debugging preload queries that may be failing
|
||||||
|
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||||
|
g.db = g.db.Debug()
|
||||||
|
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableQueryDebug disables query debugging
|
||||||
|
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||||
|
// GORM's Debug() creates a new session, so we need to get the base DB
|
||||||
|
// This is a simplified implementation
|
||||||
|
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||||
return &GormSelectQuery{db: g.db}
|
return &GormSelectQuery{db: g.db}
|
||||||
}
|
}
|
||||||
@@ -282,7 +298,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
|||||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Find(dest).Error
|
err = g.db.WithContext(ctx).Find(dest).Error
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Find(dest)
|
||||||
|
})
|
||||||
|
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
@@ -294,7 +318,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
if g.db.Statement.Model == nil {
|
if g.db.Statement.Model == nil {
|
||||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||||
}
|
}
|
||||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Find(g.db.Statement.Model)
|
||||||
|
})
|
||||||
|
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
@@ -306,6 +338,13 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
}()
|
}()
|
||||||
var count64 int64
|
var count64 int64
|
||||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Count(&count64)
|
||||||
|
})
|
||||||
|
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
return int(count64), err
|
return int(count64), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,6 +357,13 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
|||||||
}()
|
}()
|
||||||
var count int64
|
var count int64
|
||||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Limit(1).Count(&count)
|
||||||
|
})
|
||||||
|
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
}
|
||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,6 +502,13 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||||
|
if result.Error != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Updates(g.updates)
|
||||||
|
})
|
||||||
|
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||||
|
}
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -488,6 +541,13 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Delete(g.model)
|
result := g.db.WithContext(ctx).Delete(g.model)
|
||||||
|
if result.Error != nil {
|
||||||
|
// Log SQL string for debugging
|
||||||
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Delete(g.model)
|
||||||
|
})
|
||||||
|
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||||
|
}
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -9,18 +9,18 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestSqlInt16 tests SqlInt16 type
|
// TestNewSqlInt16 tests NewSqlInt16 type
|
||||||
func TestSqlInt16(t *testing.T) {
|
func TestNewSqlInt16(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input interface{}
|
input interface{}
|
||||||
expected SqlInt16
|
expected SqlInt16
|
||||||
}{
|
}{
|
||||||
{"int", 42, SqlInt16(42)},
|
{"int", 42, Null(int16(42), true)},
|
||||||
{"int32", int32(100), SqlInt16(100)},
|
{"int32", int32(100), NewSqlInt16(100)},
|
||||||
{"int64", int64(200), SqlInt16(200)},
|
{"int64", int64(200), NewSqlInt16(200)},
|
||||||
{"string", "123", SqlInt16(123)},
|
{"string", "123", NewSqlInt16(123)},
|
||||||
{"nil", nil, SqlInt16(0)},
|
{"nil", nil, Null(int16(0), false)},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlInt16_Value(t *testing.T) {
|
func TestNewSqlInt16_Value(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input SqlInt16
|
input SqlInt16
|
||||||
expected driver.Value
|
expected driver.Value
|
||||||
}{
|
}{
|
||||||
{"zero", SqlInt16(0), nil},
|
{"zero", Null(int16(0), false), nil},
|
||||||
{"positive", SqlInt16(42), int64(42)},
|
{"positive", NewSqlInt16(42), int16(42)},
|
||||||
{"negative", SqlInt16(-10), int64(-10)},
|
{"negative", NewSqlInt16(-10), int16(-10)},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlInt16_JSON(t *testing.T) {
|
func TestNewSqlInt16_JSON(t *testing.T) {
|
||||||
n := SqlInt16(42)
|
n := NewSqlInt16(42)
|
||||||
|
|
||||||
// Marshal
|
// Marshal
|
||||||
data, err := json.Marshal(n)
|
data, err := json.Marshal(n)
|
||||||
@@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) {
|
|||||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||||
t.Fatalf("Unmarshal failed: %v", err)
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
}
|
}
|
||||||
if n2 != 123 {
|
if n2.Int64() != 123 {
|
||||||
t.Errorf("expected 123, got %d", n2)
|
t.Errorf("expected 123, got %d", n2.Int64())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSqlInt64 tests SqlInt64 type
|
// TestNewSqlInt64 tests NewSqlInt64 type
|
||||||
func TestSqlInt64(t *testing.T) {
|
func TestNewSqlInt64(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input interface{}
|
input interface{}
|
||||||
expected SqlInt64
|
expected SqlInt64
|
||||||
}{
|
}{
|
||||||
{"int", 42, SqlInt64(42)},
|
{"int", 42, NewSqlInt64(42)},
|
||||||
{"int32", int32(100), SqlInt64(100)},
|
{"int32", int32(100), NewSqlInt64(100)},
|
||||||
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
|
||||||
{"uint32", uint32(100), SqlInt64(100)},
|
{"uint32", uint32(100), NewSqlInt64(100)},
|
||||||
{"uint64", uint64(200), SqlInt64(200)},
|
{"uint64", uint64(200), NewSqlInt64(200)},
|
||||||
{"nil", nil, SqlInt64(0)},
|
{"nil", nil, SqlInt64{}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) {
|
|||||||
if n.Valid != tt.valid {
|
if n.Valid != tt.valid {
|
||||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||||
}
|
}
|
||||||
if tt.valid && n.Float64 != tt.expected {
|
if tt.valid && n.Float64() != tt.expected {
|
||||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) {
|
|||||||
if err := ts.Scan(tt.input); err != nil {
|
if err := ts.Scan(tt.input); err != nil {
|
||||||
t.Fatalf("Scan failed: %v", err)
|
t.Fatalf("Scan failed: %v", err)
|
||||||
}
|
}
|
||||||
if ts.GetTime().IsZero() {
|
if ts.Time().IsZero() {
|
||||||
t.Error("expected non-zero time")
|
t.Error("expected non-zero time")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) {
|
|||||||
|
|
||||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||||
ts := SqlTimeStamp(now)
|
ts := NewSqlTimeStamp(now)
|
||||||
|
|
||||||
// Marshal
|
// Marshal
|
||||||
data, err := json.Marshal(ts)
|
data, err := json.Marshal(ts)
|
||||||
@@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) {
|
|||||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||||
t.Fatalf("Unmarshal failed: %v", err)
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
}
|
}
|
||||||
if ts2.GetTime().Year() != 2024 {
|
if ts2.Time().Year() != 2024 {
|
||||||
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test null
|
// Test null
|
||||||
@@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlDate_JSON(t *testing.T) {
|
func TestSqlDate_JSON(t *testing.T) {
|
||||||
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||||
|
|
||||||
// Marshal
|
// Marshal
|
||||||
data, err := json.Marshal(date)
|
data, err := json.Marshal(date)
|
||||||
@@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) {
|
|||||||
if u.Valid != tt.valid {
|
if u.Valid != tt.valid {
|
||||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||||
}
|
}
|
||||||
if tt.valid && u.String != tt.expected {
|
if tt.valid && u.String() != tt.expected {
|
||||||
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
t.Errorf("expected %s, got %s", tt.expected, u.String())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -480,13 +480,13 @@ func TestSqlUUID_Scan(t *testing.T) {
|
|||||||
|
|
||||||
func TestSqlUUID_Value(t *testing.T) {
|
func TestSqlUUID_Value(t *testing.T) {
|
||||||
testUUID := uuid.New()
|
testUUID := uuid.New()
|
||||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
u := NewSqlUUID(testUUID)
|
||||||
|
|
||||||
val, err := u.Value()
|
val, err := u.Value()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Value failed: %v", err)
|
t.Fatalf("Value failed: %v", err)
|
||||||
}
|
}
|
||||||
if val != testUUID.String() {
|
if val != testUUID {
|
||||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) {
|
|||||||
|
|
||||||
func TestSqlUUID_JSON(t *testing.T) {
|
func TestSqlUUID_JSON(t *testing.T) {
|
||||||
testUUID := uuid.New()
|
testUUID := uuid.New()
|
||||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
u := NewSqlUUID(testUUID)
|
||||||
|
|
||||||
// Marshal
|
// Marshal
|
||||||
data, err := json.Marshal(u)
|
data, err := json.Marshal(u)
|
||||||
@@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) {
|
|||||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||||
t.Fatalf("Unmarshal failed: %v", err)
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
}
|
}
|
||||||
if u2.String != testUUID.String() {
|
if u2.String() != testUUID.String() {
|
||||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test null
|
// Test null
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import (
|
|||||||
|
|
||||||
// MockDatabase implements common.Database interface for testing
|
// MockDatabase implements common.Database interface for testing
|
||||||
type MockDatabase struct {
|
type MockDatabase struct {
|
||||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||||
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
|
|||||||
handler := NewHandler(&MockDatabase{})
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
sqlQuery string
|
sqlQuery string
|
||||||
expectedVars []string
|
expectedVars []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "No variables",
|
name: "No variables",
|
||||||
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
|
|||||||
// TestGetIPAddress tests IP address extraction
|
// TestGetIPAddress tests IP address extraction
|
||||||
func TestGetIPAddress(t *testing.T) {
|
func TestGetIPAddress(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupReq func() *http.Request
|
setupReq func() *http.Request
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "X-Forwarded-For header",
|
name: "X-Forwarded-For header",
|
||||||
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
handler := NewHandler(&MockDatabase{})
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
userCtx := &security.UserContext{
|
userCtx := &security.UserContext{
|
||||||
UserID: 123,
|
UserID: 123,
|
||||||
UserName: "testuser",
|
UserName: "testuser",
|
||||||
SessionID: "456",
|
SessionID: "ABC456",
|
||||||
|
SessionRID: 456,
|
||||||
}
|
}
|
||||||
|
|
||||||
metainfo := map[string]interface{}{
|
metainfo := map[string]interface{}{
|
||||||
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
expectedCheck: func(result string) bool {
|
expectedCheck: func(result string) bool {
|
||||||
return strings.Contains(result, "456")
|
return strings.Contains(result, "456")
|
||||||
},
|
},
|
||||||
|
}, {
|
||||||
|
name: "Replace [id_session]",
|
||||||
|
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
|
||||||
|
expectedCheck: func(result string) bool {
|
||||||
|
return strings.Contains(result, "ABC456")
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ func NewModelRegistry() *DefaultModelRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetDefaultRegistry() *DefaultModelRegistry {
|
||||||
|
return defaultRegistry
|
||||||
|
}
|
||||||
|
|
||||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||||
registriesMutex.Lock()
|
registriesMutex.Lock()
|
||||||
defer registriesMutex.Unlock()
|
defer registriesMutex.Unlock()
|
||||||
|
|||||||
Reference in New Issue
Block a user