Compare commits

...

5 Commits

Author SHA1 Message Date
Hein
8fcb065b42 Better Query Debugging 2025-12-09 14:31:26 +02:00
Hein
dc3b621380 Fixed test for session id changes 2025-12-09 14:07:00 +02:00
Hein
a4dd2a7086 exposed types FromString 2025-12-09 14:03:55 +02:00
Hein
3ec2e5f15a Proper handling of fromString in the types 2025-12-09 13:55:51 +02:00
Hein
c52afe2825 Updated sql types 2025-12-09 13:14:22 +02:00
5 changed files with 534 additions and 643 deletions

View File

@@ -410,6 +410,9 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
// Execute the main query first
err = b.query.Scan(ctx, dest)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
@@ -438,6 +441,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
// Execute the main query first
err = b.query.Scan(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
@@ -573,15 +579,25 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err
}
// Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model
err = b.db.NewSelect().
countQuery := b.db.NewSelect().
TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)").
Scan(ctx, &count)
ColumnExpr("COUNT(*)")
err = countQuery.Scan(ctx, &count)
if err != nil {
// Log SQL string for debugging
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err
}
@@ -592,7 +608,13 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
exists = false
}
}()
return b.query.Exists(ctx)
exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return exists, err
}
// BunInsertQuery implements InsertQuery for Bun
@@ -729,6 +751,11 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
}
}()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err
}
@@ -759,6 +786,11 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
}
}()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err
}

View File

@@ -282,7 +282,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
return g.db.WithContext(ctx).Find(dest).Error
err = g.db.WithContext(ctx).Find(dest).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(dest)
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
}
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
@@ -294,7 +302,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(g.db.Statement.Model)
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
}
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
@@ -306,6 +322,13 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Count(&count64)
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return int(count64), err
}
@@ -318,6 +341,13 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
}()
var count int64
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Limit(1).Count(&count)
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return count > 0, err
}
@@ -456,6 +486,13 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Updates(g.updates)
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error
}
@@ -488,6 +525,13 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Delete(g.model)
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,18 +9,18 @@ import (
"github.com/google/uuid"
)
// TestSqlInt16 tests SqlInt16 type
func TestSqlInt16(t *testing.T) {
// TestNewSqlInt16 tests NewSqlInt16 type
func TestNewSqlInt16(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt16
}{
{"int", 42, SqlInt16(42)},
{"int32", int32(100), SqlInt16(100)},
{"int64", int64(200), SqlInt16(200)},
{"string", "123", SqlInt16(123)},
{"nil", nil, SqlInt16(0)},
{"int", 42, Null(int16(42), true)},
{"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), NewSqlInt16(200)},
{"string", "123", NewSqlInt16(123)},
{"nil", nil, Null(int16(0), false)},
}
for _, tt := range tests {
@@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) {
}
}
func TestSqlInt16_Value(t *testing.T) {
func TestNewSqlInt16_Value(t *testing.T) {
tests := []struct {
name string
input SqlInt16
expected driver.Value
}{
{"zero", SqlInt16(0), nil},
{"positive", SqlInt16(42), int64(42)},
{"negative", SqlInt16(-10), int64(-10)},
{"zero", Null(int16(0), false), nil},
{"positive", NewSqlInt16(42), int16(42)},
{"negative", NewSqlInt16(-10), int16(-10)},
}
for _, tt := range tests {
@@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) {
}
}
func TestSqlInt16_JSON(t *testing.T) {
n := SqlInt16(42)
func TestNewSqlInt16_JSON(t *testing.T) {
n := NewSqlInt16(42)
// Marshal
data, err := json.Marshal(n)
@@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) {
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if n2 != 123 {
t.Errorf("expected 123, got %d", n2)
if n2.Int64() != 123 {
t.Errorf("expected 123, got %d", n2.Int64())
}
}
// TestSqlInt64 tests SqlInt64 type
func TestSqlInt64(t *testing.T) {
// TestNewSqlInt64 tests NewSqlInt64 type
func TestNewSqlInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt64
}{
{"int", 42, SqlInt64(42)},
{"int32", int32(100), SqlInt64(100)},
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
{"uint32", uint32(100), SqlInt64(100)},
{"uint64", uint64(200), SqlInt64(200)},
{"nil", nil, SqlInt64(0)},
{"int", 42, NewSqlInt64(42)},
{"int32", int32(100), NewSqlInt64(100)},
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
{"uint32", uint32(100), NewSqlInt64(100)},
{"uint64", uint64(200), NewSqlInt64(200)},
{"nil", nil, SqlInt64{}},
}
for _, tt := range tests {
@@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) {
if n.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
}
if tt.valid && n.Float64 != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
if tt.valid && n.Float64() != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
}
})
}
@@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) {
if err := ts.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if ts.GetTime().IsZero() {
if ts.Time().IsZero() {
t.Error("expected non-zero time")
}
})
@@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) {
func TestSqlTimeStamp_JSON(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
ts := SqlTimeStamp(now)
ts := NewSqlTimeStamp(now)
// Marshal
data, err := json.Marshal(ts)
@@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) {
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if ts2.GetTime().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
if ts2.Time().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
}
// Test null
@@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) {
}
func TestSqlDate_JSON(t *testing.T) {
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
// Marshal
data, err := json.Marshal(date)
@@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) {
if u.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
}
if tt.valid && u.String != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String)
if tt.valid && u.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String())
}
})
}
@@ -480,13 +480,13 @@ func TestSqlUUID_Scan(t *testing.T) {
func TestSqlUUID_Value(t *testing.T) {
testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true}
u := NewSqlUUID(testUUID)
val, err := u.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != testUUID.String() {
if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val)
}
@@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) {
func TestSqlUUID_JSON(t *testing.T) {
testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true}
u := NewSqlUUID(testUUID)
// Marshal
data, err := json.Marshal(u)
@@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) {
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if u2.String != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
if u2.String() != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
}
// Test null

View File

@@ -16,8 +16,8 @@ import (
// MockDatabase implements common.Database interface for testing
type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
expectedVars []string
name string
sqlQuery string
expectedVars []string
}{
{
name: "No variables",
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
// TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) {
tests := []struct {
name string
setupReq func() *http.Request
expected string
name string
setupReq func() *http.Request
expected string
}{
{
name: "X-Forwarded-For header",
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{
UserID: 123,
UserName: "testuser",
SessionID: "456",
UserID: 123,
UserName: "testuser",
SessionID: "ABC456",
SessionRID: 456,
}
metainfo := map[string]interface{}{
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
expectedCheck: func(result string) bool {
return strings.Contains(result, "456")
},
}, {
name: "Replace [id_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "ABC456")
},
},
}