Compare commits

...

3 Commits

Author SHA1 Message Date
Hein
dc3b621380 Fixed test for session id changes 2025-12-09 14:07:00 +02:00
Hein
a4dd2a7086 exposed types FromString 2025-12-09 14:03:55 +02:00
Hein
3ec2e5f15a Proper handling of fromString in the types 2025-12-09 13:55:51 +02:00
3 changed files with 59 additions and 32 deletions

View File

@@ -71,15 +71,14 @@ func (n *SqlNull[T]) Scan(value any) error {
// Fallback: parse from string/bytes. // Fallback: parse from string/bytes.
switch v := value.(type) { switch v := value.(type) {
case string: case string:
return n.fromString(v) return n.FromString(v)
case []byte: case []byte:
return n.fromString(string(v)) return n.FromString(string(v))
default: default:
return n.fromString(fmt.Sprintf("%v", value)) return n.FromString(fmt.Sprintf("%v", value))
} }
} }
func (n *SqlNull[T]) FromString(s string) error {
func (n *SqlNull[T]) fromString(s string) error {
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
n.Valid = false n.Valid = false
n.Val = *new(T) n.Val = *new(T)
@@ -90,19 +89,14 @@ func (n *SqlNull[T]) fromString(s string) error {
var zero T var zero T
switch any(zero).(type) { switch any(zero).(type) {
case int, int8, int16, int32, int64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if i, err := strconv.ParseInt(s, 10, 64); err == nil { if i, err := strconv.ParseInt(s, 10, 64); err == nil {
n.Val = any(int64(i)).(T) // Cast to T (e.g., int16) reflect.ValueOf(&n.Val).Elem().SetInt(i)
n.Valid = true
}
case uint, uint8, uint16, uint32, uint64:
if u, err := strconv.ParseUint(s, 10, 64); err == nil {
n.Val = any(u).(T)
n.Valid = true n.Valid = true
} }
case float32, float64: case float32, float64:
if f, err := strconv.ParseFloat(s, 64); err == nil { if f, err := strconv.ParseFloat(s, 64); err == nil {
n.Val = any(f).(T) reflect.ValueOf(&n.Val).Elem().SetFloat(f)
n.Valid = true n.Valid = true
} }
case bool: case bool:
@@ -124,7 +118,6 @@ func (n *SqlNull[T]) fromString(s string) error {
n.Val = any(s).(T) n.Val = any(s).(T)
n.Valid = true n.Valid = true
} }
return nil return nil
} }
@@ -163,7 +156,7 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
// Fallback: unmarshal as string and parse. // Fallback: unmarshal as string and parse.
var s string var s string
if err := json.Unmarshal(b, &s); err == nil { if err := json.Unmarshal(b, &s); err == nil {
return n.fromString(s) return n.FromString(s)
} }
return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val) return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val)
@@ -517,6 +510,33 @@ func TryIfInt64(v any, def int64) int64 {
} }
// Constructor helpers - clean and fast value creation // Constructor helpers - clean and fast value creation
func Null[T any](v T, valid bool) SqlNull[T] {
return SqlNull[T]{Val: v, Valid: valid}
}
func NewSql[T any](value any) SqlNull[T] {
n := SqlNull[T]{}
if value == nil {
return n
}
// Fast path: exact match
if v, ok := value.(T); ok {
n.Val = v
n.Valid = true
return n
}
// Try from another SqlNull
if sn, ok := value.(SqlNull[T]); ok {
return sn
}
// Convert via string
_ = n.FromString(fmt.Sprintf("%v", value))
return n
}
func NewSqlInt16(v int16) SqlInt16 { func NewSqlInt16(v int16) SqlInt16 {
return SqlInt16{Val: v, Valid: true} return SqlInt16{Val: v, Valid: true}

View File

@@ -16,11 +16,11 @@ func TestNewSqlInt16(t *testing.T) {
input interface{} input interface{}
expected SqlInt16 expected SqlInt16
}{ }{
{"int", 42, NewSqlInt16(42)}, {"int", 42, Null(int16(42), true)},
{"int32", int32(100), NewSqlInt16(100)}, {"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), NewSqlInt16(200)}, {"int64", int64(200), NewSqlInt16(200)},
{"string", "123", NewSqlInt16(123)}, {"string", "123", NewSqlInt16(123)},
{"nil", nil, NewSqlInt16(0)}, {"nil", nil, Null(int16(0), false)},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -42,9 +42,9 @@ func TestNewSqlInt16_Value(t *testing.T) {
input SqlInt16 input SqlInt16
expected driver.Value expected driver.Value
}{ }{
{"zero", NewSqlInt16(0), nil}, {"zero", Null(int16(0), false), nil},
{"positive", NewSqlInt16(42), int64(42)}, {"positive", NewSqlInt16(42), int16(42)},
{"negative", NewSqlInt16(-10), int64(-10)}, {"negative", NewSqlInt16(-10), int16(-10)},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -486,7 +486,7 @@ func TestSqlUUID_Value(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Value failed: %v", err) t.Fatalf("Value failed: %v", err)
} }
if val != testUUID.String() { if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val) t.Errorf("expected %s, got %s", testUUID.String(), val)
} }

View File

@@ -16,8 +16,8 @@ import (
// MockDatabase implements common.Database interface for testing // MockDatabase implements common.Database interface for testing
type MockDatabase struct { type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error) ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
} }
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{}) handler := NewHandler(&MockDatabase{})
tests := []struct { tests := []struct {
name string name string
sqlQuery string sqlQuery string
expectedVars []string expectedVars []string
}{ }{
{ {
name: "No variables", name: "No variables",
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
// TestGetIPAddress tests IP address extraction // TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) { func TestGetIPAddress(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupReq func() *http.Request setupReq func() *http.Request
expected string expected string
}{ }{
{ {
name: "X-Forwarded-For header", name: "X-Forwarded-For header",
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{}) handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{ userCtx := &security.UserContext{
UserID: 123, UserID: 123,
UserName: "testuser", UserName: "testuser",
SessionID: "456", SessionID: "ABC456",
SessionRID: 456,
} }
metainfo := map[string]interface{}{ metainfo := map[string]interface{}{
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
expectedCheck: func(result string) bool { expectedCheck: func(result string) bool {
return strings.Contains(result, "456") return strings.Contains(result, "456")
}, },
}, {
name: "Replace [id_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "ABC456")
},
}, },
} }