From 3ec2e5f15a3304bc664101deb4093101be602283 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 9 Dec 2025 13:55:51 +0200 Subject: [PATCH] Proper handling of fromString in the types --- pkg/common/sql_types.go | 16 ++++++---------- pkg/common/sql_types_test.go | 12 ++++++------ 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/pkg/common/sql_types.go b/pkg/common/sql_types.go index cef9a0d..ee31d91 100644 --- a/pkg/common/sql_types.go +++ b/pkg/common/sql_types.go @@ -78,7 +78,6 @@ func (n *SqlNull[T]) Scan(value any) error { return n.fromString(fmt.Sprintf("%v", value)) } } - func (n *SqlNull[T]) fromString(s string) error { s = strings.TrimSpace(s) n.Valid = false @@ -90,19 +89,14 @@ func (n *SqlNull[T]) fromString(s string) error { var zero T switch any(zero).(type) { - case int, int8, int16, int32, int64: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: if i, err := strconv.ParseInt(s, 10, 64); err == nil { - n.Val = any(int64(i)).(T) // Cast to T (e.g., int16) - n.Valid = true - } - case uint, uint8, uint16, uint32, uint64: - if u, err := strconv.ParseUint(s, 10, 64); err == nil { - n.Val = any(u).(T) + reflect.ValueOf(&n.Val).Elem().SetInt(i) n.Valid = true } case float32, float64: if f, err := strconv.ParseFloat(s, 64); err == nil { - n.Val = any(f).(T) + reflect.ValueOf(&n.Val).Elem().SetFloat(f) n.Valid = true } case bool: @@ -124,7 +118,6 @@ func (n *SqlNull[T]) fromString(s string) error { n.Val = any(s).(T) n.Valid = true } - return nil } @@ -517,6 +510,9 @@ func TryIfInt64(v any, def int64) int64 { } // Constructor helpers - clean and fast value creation +func Null[T any](v T, valid bool) SqlNull[T] { + return SqlNull[T]{Val: v, Valid: valid} +} func NewSqlInt16(v int16) SqlInt16 { return SqlInt16{Val: v, Valid: true} diff --git a/pkg/common/sql_types_test.go b/pkg/common/sql_types_test.go index 0117e52..86d1c88 100644 --- a/pkg/common/sql_types_test.go +++ b/pkg/common/sql_types_test.go @@ -16,11 +16,11 @@ func TestNewSqlInt16(t *testing.T) { input interface{} expected SqlInt16 }{ - {"int", 42, NewSqlInt16(42)}, + {"int", 42, Null(int16(42), true)}, {"int32", int32(100), NewSqlInt16(100)}, {"int64", int64(200), NewSqlInt16(200)}, {"string", "123", NewSqlInt16(123)}, - {"nil", nil, NewSqlInt16(0)}, + {"nil", nil, Null(int16(0), false)}, } for _, tt := range tests { @@ -42,9 +42,9 @@ func TestNewSqlInt16_Value(t *testing.T) { input SqlInt16 expected driver.Value }{ - {"zero", NewSqlInt16(0), nil}, - {"positive", NewSqlInt16(42), int64(42)}, - {"negative", NewSqlInt16(-10), int64(-10)}, + {"zero", Null(int16(0), false), nil}, + {"positive", NewSqlInt16(42), int16(42)}, + {"negative", NewSqlInt16(-10), int16(-10)}, } for _, tt := range tests { @@ -486,7 +486,7 @@ func TestSqlUUID_Value(t *testing.T) { if err != nil { t.Fatalf("Value failed: %v", err) } - if val != testUUID.String() { + if val != testUUID { t.Errorf("expected %s, got %s", testUUID.String(), val) }