[breaking] Another breaking change datatypes -> spectypes

This commit is contained in:
Hein
2025-12-18 11:49:35 +02:00
parent 77f39af2f9
commit 7c861c708e
3 changed files with 13 additions and 13 deletions

579
pkg/spectypes/sql_types.go Normal file
View File

@@ -0,0 +1,579 @@
// Package spectypes provides nullable SQL types with automatic casting and conversion methods.
package spectypes
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
// tryParseDT attempts to parse a string into a time.Time using various formats.
func tryParseDT(str string) (time.Time, error) {
var lasterror error
tryFormats := []string{
time.RFC3339,
"2006-01-02T15:04:05.000-0700",
"2006-01-02T15:04:05.000",
"06-01-02T15:04:05.000",
"2006-01-02T15:04:05",
"2006-01-02 15:04:05",
"02/01/2006",
"02-01-2006",
"2006-01-02",
"15:04:05.000",
"15:04:05",
"15:04",
}
for _, f := range tryFormats {
tx, err := time.Parse(f, str)
if err == nil {
return tx, nil
}
lasterror = err
}
return time.Time{}, lasterror // Return zero time on failure
}
// ToJSONDT formats a time.Time to RFC3339 string.
func ToJSONDT(dt time.Time) string {
return dt.Format(time.RFC3339)
}
// SqlNull is a generic nullable type that behaves like sql.NullXXX with auto-casting.
type SqlNull[T any] struct {
Val T
Valid bool
}
// Scan implements sql.Scanner.
func (n *SqlNull[T]) Scan(value any) error {
if value == nil {
n.Valid = false
n.Val = *new(T)
return nil
}
// Try standard sql.Null[T] first.
var sqlNull sql.Null[T]
if err := sqlNull.Scan(value); err == nil {
n.Val = sqlNull.V
n.Valid = sqlNull.Valid
return nil
}
// Fallback: parse from string/bytes.
switch v := value.(type) {
case string:
return n.FromString(v)
case []byte:
return n.FromString(string(v))
default:
return n.FromString(fmt.Sprintf("%v", value))
}
}
func (n *SqlNull[T]) FromString(s string) error {
s = strings.TrimSpace(s)
n.Valid = false
n.Val = *new(T)
if s == "" || strings.EqualFold(s, "null") {
return nil
}
var zero T
switch any(zero).(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
reflect.ValueOf(&n.Val).Elem().SetInt(i)
n.Valid = true
}
case float32, float64:
if f, err := strconv.ParseFloat(s, 64); err == nil {
reflect.ValueOf(&n.Val).Elem().SetFloat(f)
n.Valid = true
}
case bool:
if b, err := strconv.ParseBool(s); err == nil {
n.Val = any(b).(T)
n.Valid = true
}
case time.Time:
if t, err := tryParseDT(s); err == nil && !t.IsZero() {
n.Val = any(t).(T)
n.Valid = true
}
case uuid.UUID:
if u, err := uuid.Parse(s); err == nil {
n.Val = any(u).(T)
n.Valid = true
}
case string:
n.Val = any(s).(T)
n.Valid = true
}
return nil
}
// Value implements driver.Valuer.
func (n SqlNull[T]) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return any(n.Val), nil
}
// MarshalJSON implements json.Marshaler.
func (n SqlNull[T]) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return json.Marshal(n.Val)
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
if len(b) == 0 || string(b) == "null" || strings.TrimSpace(string(b)) == "" {
n.Valid = false
n.Val = *new(T)
return nil
}
// Try direct unmarshal.
var val T
if err := json.Unmarshal(b, &val); err == nil {
n.Val = val
n.Valid = true
return nil
}
// Fallback: unmarshal as string and parse.
var s string
if err := json.Unmarshal(b, &s); err == nil {
return n.FromString(s)
}
return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val)
}
// String implements fmt.Stringer.
func (n SqlNull[T]) String() string {
if !n.Valid {
return ""
}
return fmt.Sprintf("%v", n.Val)
}
// Int64 converts to int64 or 0 if invalid.
func (n SqlNull[T]) Int64() int64 {
if !n.Valid {
return 0
}
v := reflect.ValueOf(any(n.Val))
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return int64(v.Uint())
case reflect.Float32, reflect.Float64:
return int64(v.Float())
case reflect.String:
i, _ := strconv.ParseInt(v.String(), 10, 64)
return i
case reflect.Bool:
if v.Bool() {
return 1
}
return 0
}
return 0
}
// Float64 converts to float64 or 0.0 if invalid.
func (n SqlNull[T]) Float64() float64 {
if !n.Valid {
return 0.0
}
v := reflect.ValueOf(any(n.Val))
switch v.Kind() {
case reflect.Float32, reflect.Float64:
return v.Float()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return float64(v.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return float64(v.Uint())
case reflect.String:
f, _ := strconv.ParseFloat(v.String(), 64)
return f
}
return 0.0
}
// Bool converts to bool or false if invalid.
func (n SqlNull[T]) Bool() bool {
if !n.Valid {
return false
}
v := reflect.ValueOf(any(n.Val))
if v.Kind() == reflect.Bool {
return v.Bool()
}
s := strings.ToLower(strings.TrimSpace(fmt.Sprint(n.Val)))
return s == "true" || s == "t" || s == "1" || s == "yes" || s == "on"
}
// Time converts to time.Time or zero if invalid.
func (n SqlNull[T]) Time() time.Time {
if !n.Valid {
return time.Time{}
}
if t, ok := any(n.Val).(time.Time); ok {
return t
}
return time.Time{}
}
// UUID converts to uuid.UUID or Nil if invalid.
func (n SqlNull[T]) UUID() uuid.UUID {
if !n.Valid {
return uuid.Nil
}
if u, ok := any(n.Val).(uuid.UUID); ok {
return u
}
return uuid.Nil
}
// Type aliases for common types.
type (
SqlInt16 = SqlNull[int16]
SqlInt32 = SqlNull[int32]
SqlInt64 = SqlNull[int64]
SqlFloat64 = SqlNull[float64]
SqlBool = SqlNull[bool]
SqlString = SqlNull[string]
SqlUUID = SqlNull[uuid.UUID]
)
// SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS).
type SqlTimeStamp struct{ SqlNull[time.Time] }
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
return []byte("null"), nil
}
return []byte(fmt.Sprintf(`"%s"`, t.Val.Format("2006-01-02T15:04:05"))), nil
}
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
if err := t.SqlNull.UnmarshalJSON(b); err != nil {
return err
}
if t.Valid && (t.Val.IsZero() || t.Val.Format("2006-01-02T15:04:05") == "0001-01-01T00:00:00") {
t.Valid = false
}
return nil
}
func (t SqlTimeStamp) Value() (driver.Value, error) {
if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
return nil, nil
}
return t.Val.Format("2006-01-02T15:04:05"), nil
}
func SqlTimeStampNow() SqlTimeStamp {
return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
}
// SqlDate - Date only (YYYY-MM-DD).
type SqlDate struct{ SqlNull[time.Time] }
func (d SqlDate) MarshalJSON() ([]byte, error) {
if !d.Valid || d.Val.IsZero() {
return []byte("null"), nil
}
s := d.Val.Format("2006-01-02")
if strings.HasPrefix(s, "0001-01-01") {
return []byte("null"), nil
}
return []byte(fmt.Sprintf(`"%s"`, s)), nil
}
func (d *SqlDate) UnmarshalJSON(b []byte) error {
if err := d.SqlNull.UnmarshalJSON(b); err != nil {
return err
}
if d.Valid && d.Val.Format("2006-01-02") <= "0001-01-01" {
d.Valid = false
}
return nil
}
func (d SqlDate) Value() (driver.Value, error) {
if !d.Valid || d.Val.IsZero() {
return nil, nil
}
s := d.Val.Format("2006-01-02")
if s <= "0001-01-01" {
return nil, nil
}
return s, nil
}
func (d SqlDate) String() string {
if !d.Valid {
return ""
}
s := d.Val.Format("2006-01-02")
if strings.HasPrefix(s, "0001-01-01") || strings.HasPrefix(s, "1800-12-31") {
return ""
}
return s
}
func SqlDateNow() SqlDate {
return SqlDate{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
}
// SqlTime - Time only (HH:MM:SS).
type SqlTime struct{ SqlNull[time.Time] }
func (t SqlTime) MarshalJSON() ([]byte, error) {
if !t.Valid || t.Val.IsZero() {
return []byte("null"), nil
}
s := t.Val.Format("15:04:05")
if s == "00:00:00" {
return []byte("null"), nil
}
return []byte(fmt.Sprintf(`"%s"`, s)), nil
}
func (t *SqlTime) UnmarshalJSON(b []byte) error {
if err := t.SqlNull.UnmarshalJSON(b); err != nil {
return err
}
if t.Valid && t.Val.Format("15:04:05") == "00:00:00" {
t.Valid = false
}
return nil
}
func (t SqlTime) Value() (driver.Value, error) {
if !t.Valid || t.Val.IsZero() {
return nil, nil
}
return t.Val.Format("15:04:05"), nil
}
func (t SqlTime) String() string {
if !t.Valid {
return ""
}
return t.Val.Format("15:04:05")
}
func SqlTimeNow() SqlTime {
return SqlTime{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
}
// SqlJSONB - Nullable JSONB as []byte.
type SqlJSONB []byte
// Scan implements sql.Scanner.
func (n *SqlJSONB) Scan(value any) error {
if value == nil {
*n = nil
return nil
}
switch v := value.(type) {
case string:
*n = []byte(v)
case []byte:
*n = v
default:
dat, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value to JSON: %v", err)
}
*n = dat
}
return nil
}
// Value implements driver.Valuer.
func (n SqlJSONB) Value() (driver.Value, error) {
if len(n) == 0 {
return nil, nil
}
var js any
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return string(n), nil
}
// MarshalJSON implements json.Marshaler.
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
if len(n) == 0 {
return []byte("null"), nil
}
var obj any
if err := json.Unmarshal(n, &obj); err != nil {
return []byte("null"), nil
}
return n, nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
s := strings.TrimSpace(string(b))
if s == "null" || s == "" || (!strings.HasPrefix(s, "{") && !strings.HasPrefix(s, "[")) {
*n = nil
return nil
}
*n = b
return nil
}
func (n SqlJSONB) AsMap() (map[string]any, error) {
if len(n) == 0 {
return nil, nil
}
js := make(map[string]any)
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return js, nil
}
func (n SqlJSONB) AsSlice() ([]any, error) {
if len(n) == 0 {
return nil, nil
}
js := make([]any, 0)
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return js, nil
}
// TryIfInt64 tries to parse any value to int64 with default.
func TryIfInt64(v any, def int64) int64 {
switch val := v.(type) {
case string:
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return def
}
return i
case int:
return int64(val)
case int8:
return int64(val)
case int16:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint:
return int64(val)
case uint8:
return int64(val)
case uint16:
return int64(val)
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case []byte:
i, err := strconv.ParseInt(string(val), 10, 64)
if err != nil {
return def
}
return i
default:
return def
}
}
// Constructor helpers - clean and fast value creation
func Null[T any](v T, valid bool) SqlNull[T] {
return SqlNull[T]{Val: v, Valid: valid}
}
func NewSql[T any](value any) SqlNull[T] {
n := SqlNull[T]{}
if value == nil {
return n
}
// Fast path: exact match
if v, ok := value.(T); ok {
n.Val = v
n.Valid = true
return n
}
// Try from another SqlNull
if sn, ok := value.(SqlNull[T]); ok {
return sn
}
// Convert via string
_ = n.FromString(fmt.Sprintf("%v", value))
return n
}
func NewSqlInt16(v int16) SqlInt16 {
return SqlInt16{Val: v, Valid: true}
}
func NewSqlInt32(v int32) SqlInt32 {
return SqlInt32{Val: v, Valid: true}
}
func NewSqlInt64(v int64) SqlInt64 {
return SqlInt64{Val: v, Valid: true}
}
func NewSqlFloat64(v float64) SqlFloat64 {
return SqlFloat64{Val: v, Valid: true}
}
func NewSqlBool(v bool) SqlBool {
return SqlBool{Val: v, Valid: true}
}
func NewSqlString(v string) SqlString {
return SqlString{Val: v, Valid: true}
}
func NewSqlUUID(v uuid.UUID) SqlUUID {
return SqlUUID{Val: v, Valid: true}
}
func NewSqlTimeStamp(v time.Time) SqlTimeStamp {
return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
}
func NewSqlDate(v time.Time) SqlDate {
return SqlDate{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
}
func NewSqlTime(v time.Time) SqlTime {
return SqlTime{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
}

View File

@@ -0,0 +1,566 @@
package spectypes
import (
"database/sql/driver"
"encoding/json"
"testing"
"time"
"github.com/google/uuid"
)
// TestNewSqlInt16 tests NewSqlInt16 type
func TestNewSqlInt16(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt16
}{
{"int", 42, Null(int16(42), true)},
{"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), NewSqlInt16(200)},
{"string", "123", NewSqlInt16(123)},
{"nil", nil, Null(int16(0), false)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlInt16
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n)
}
})
}
}
func TestNewSqlInt16_Value(t *testing.T) {
tests := []struct {
name string
input SqlInt16
expected driver.Value
}{
{"zero", Null(int16(0), false), nil},
{"positive", NewSqlInt16(42), int16(42)},
{"negative", NewSqlInt16(-10), int16(-10)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := tt.input.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, val)
}
})
}
}
func TestNewSqlInt16_JSON(t *testing.T) {
n := NewSqlInt16(42)
// Marshal
data, err := json.Marshal(n)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := "42"
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var n2 SqlInt16
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if n2.Int64() != 123 {
t.Errorf("expected 123, got %d", n2.Int64())
}
}
// TestNewSqlInt64 tests NewSqlInt64 type
func TestNewSqlInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt64
}{
{"int", 42, NewSqlInt64(42)},
{"int32", int32(100), NewSqlInt64(100)},
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
{"uint32", uint32(100), NewSqlInt64(100)},
{"uint64", uint64(200), NewSqlInt64(200)},
{"nil", nil, SqlInt64{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlInt64
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n)
}
})
}
}
// TestSqlFloat64 tests SqlFloat64 type
func TestSqlFloat64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected float64
valid bool
}{
{"float64", float64(3.14), 3.14, true},
{"float32", float32(2.5), 2.5, true},
{"int", 42, 42.0, true},
{"int64", int64(100), 100.0, true},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlFloat64
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
}
if tt.valid && n.Float64() != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
}
})
}
}
// TestSqlTimeStamp tests SqlTimeStamp type
func TestSqlTimeStamp(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
}{
{"time.Time", now},
{"string RFC3339", now.Format(time.RFC3339)},
{"string date", "2024-01-15"},
{"string datetime", "2024-01-15T10:30:00"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ts SqlTimeStamp
if err := ts.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if ts.Time().IsZero() {
t.Error("expected non-zero time")
}
})
}
}
func TestSqlTimeStamp_JSON(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
ts := NewSqlTimeStamp(now)
// Marshal
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"2024-01-15T10:30:45"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var ts2 SqlTimeStamp
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if ts2.Time().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
}
// Test null
var ts3 SqlTimeStamp
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
}
// TestSqlDate tests SqlDate type
func TestSqlDate(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
}{
{"time.Time", now},
{"string date", "2024-01-15"},
{"string UK format", "15/01/2024"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var d SqlDate
if err := d.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if d.String() == "0" {
t.Error("expected non-zero date")
}
})
}
}
func TestSqlDate_JSON(t *testing.T) {
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
// Marshal
data, err := json.Marshal(date)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"2024-01-15"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var d2 SqlDate
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
}
// TestSqlTime tests SqlTime type
func TestSqlTime(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
expected string
}{
{"time.Time", now, now.Format("15:04:05")},
{"string time", "10:30:45", "10:30:45"},
{"string short time", "10:30", "10:30:00"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tm SqlTime
if err := tm.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if tm.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, tm.String())
}
})
}
}
// TestSqlJSONB tests SqlJSONB type
func TestSqlJSONB_Scan(t *testing.T) {
tests := []struct {
name string
input interface{}
expected string
}{
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
{"nil", nil, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var j SqlJSONB
if err := j.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if tt.expected == "" && j == nil {
return // nil case
}
if string(j) != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, string(j))
}
})
}
}
func TestSqlJSONB_Value(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
expected string
wantErr bool
}{
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
{"empty", SqlJSONB{}, "", false},
{"nil", nil, "", false},
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := tt.input.Value()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if tt.expected == "" && val == nil {
return // nil case
}
if val.(string) != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, val)
}
})
}
}
func TestSqlJSONB_JSON(t *testing.T) {
// Marshal
j := SqlJSONB(`{"name":"test","count":42}`)
data, err := json.Marshal(j)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("Unmarshal result failed: %v", err)
}
if result["name"] != "test" {
t.Errorf("expected name=test, got %v", result["name"])
}
// Unmarshal
var j2 SqlJSONB
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if string(j2) != `{"key":"value"}` {
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
}
// Test null
var j3 SqlJSONB
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
}
func TestSqlJSONB_AsMap(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
wantErr bool
wantNil bool
}{
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
{"empty", SqlJSONB{}, false, true},
{"nil", nil, false, true},
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m, err := tt.input.AsMap()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("AsMap failed: %v", err)
}
if tt.wantNil {
if m != nil {
t.Errorf("expected nil, got %v", m)
}
return
}
if m == nil {
t.Error("expected non-nil map")
}
})
}
}
func TestSqlJSONB_AsSlice(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
wantErr bool
wantNil bool
}{
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
{"empty", SqlJSONB{}, false, true},
{"nil", nil, false, true},
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := tt.input.AsSlice()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("AsSlice failed: %v", err)
}
if tt.wantNil {
if s != nil {
t.Errorf("expected nil, got %v", s)
}
return
}
if s == nil {
t.Error("expected non-nil slice")
}
})
}
}
// TestSqlUUID tests SqlUUID type
func TestSqlUUID_Scan(t *testing.T) {
testUUID := uuid.New()
testUUIDStr := testUUID.String()
tests := []struct {
name string
input interface{}
expected string
valid bool
}{
{"string UUID", testUUIDStr, testUUIDStr, true},
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
{"nil", nil, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u SqlUUID
if err := u.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if u.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
}
if tt.valid && u.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String())
}
})
}
}
func TestSqlUUID_Value(t *testing.T) {
testUUID := uuid.New()
u := NewSqlUUID(testUUID)
val, err := u.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val)
}
// Test invalid UUID
u2 := SqlUUID{Valid: false}
val2, err := u2.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val2 != nil {
t.Errorf("expected nil, got %v", val2)
}
}
func TestSqlUUID_JSON(t *testing.T) {
testUUID := uuid.New()
u := NewSqlUUID(testUUID)
// Marshal
data, err := json.Marshal(u)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"` + testUUID.String() + `"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var u2 SqlUUID
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if u2.String() != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
}
// Test null
var u3 SqlUUID
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
if u3.Valid {
t.Error("expected invalid UUID")
}
}
// TestTryIfInt64 tests the TryIfInt64 helper function
func TestTryIfInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
def int64
expected int64
}{
{"string valid", "123", 0, 123},
{"string invalid", "abc", 99, 99},
{"int", 42, 0, 42},
{"int32", int32(100), 0, 100},
{"int64", int64(200), 0, 200},
{"uint32", uint32(50), 0, 50},
{"uint64", uint64(75), 0, 75},
{"float32", float32(3.14), 0, 3},
{"float64", float64(2.71), 0, 2},
{"bytes", []byte("456"), 0, 456},
{"unknown type", struct{}{}, 999, 999},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := TryIfInt64(tt.input, tt.def)
if result != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, result)
}
})
}
}