mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-05 03:14:25 +00:00
[breaking] Another breaking change datatypes -> spectypes
This commit is contained in:
579
pkg/spectypes/sql_types.go
Normal file
579
pkg/spectypes/sql_types.go
Normal file
@@ -0,0 +1,579 @@
|
||||
// Package spectypes provides nullable SQL types with automatic casting and conversion methods.
|
||||
package spectypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// tryParseDT attempts to parse a string into a time.Time using various formats.
|
||||
func tryParseDT(str string) (time.Time, error) {
|
||||
var lasterror error
|
||||
tryFormats := []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02T15:04:05.000-0700",
|
||||
"2006-01-02T15:04:05.000",
|
||||
"06-01-02T15:04:05.000",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04:05",
|
||||
"02/01/2006",
|
||||
"02-01-2006",
|
||||
"2006-01-02",
|
||||
"15:04:05.000",
|
||||
"15:04:05",
|
||||
"15:04",
|
||||
}
|
||||
for _, f := range tryFormats {
|
||||
tx, err := time.Parse(f, str)
|
||||
if err == nil {
|
||||
return tx, nil
|
||||
}
|
||||
lasterror = err
|
||||
}
|
||||
return time.Time{}, lasterror // Return zero time on failure
|
||||
}
|
||||
|
||||
// ToJSONDT formats a time.Time to RFC3339 string.
|
||||
func ToJSONDT(dt time.Time) string {
|
||||
return dt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// SqlNull is a generic nullable type that behaves like sql.NullXXX with auto-casting.
|
||||
type SqlNull[T any] struct {
|
||||
Val T
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner.
|
||||
func (n *SqlNull[T]) Scan(value any) error {
|
||||
if value == nil {
|
||||
n.Valid = false
|
||||
n.Val = *new(T)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try standard sql.Null[T] first.
|
||||
var sqlNull sql.Null[T]
|
||||
if err := sqlNull.Scan(value); err == nil {
|
||||
n.Val = sqlNull.V
|
||||
n.Valid = sqlNull.Valid
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback: parse from string/bytes.
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return n.FromString(v)
|
||||
case []byte:
|
||||
return n.FromString(string(v))
|
||||
default:
|
||||
return n.FromString(fmt.Sprintf("%v", value))
|
||||
}
|
||||
}
|
||||
func (n *SqlNull[T]) FromString(s string) error {
|
||||
s = strings.TrimSpace(s)
|
||||
n.Valid = false
|
||||
n.Val = *new(T)
|
||||
|
||||
if s == "" || strings.EqualFold(s, "null") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
reflect.ValueOf(&n.Val).Elem().SetInt(i)
|
||||
n.Valid = true
|
||||
}
|
||||
case float32, float64:
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
reflect.ValueOf(&n.Val).Elem().SetFloat(f)
|
||||
n.Valid = true
|
||||
}
|
||||
case bool:
|
||||
if b, err := strconv.ParseBool(s); err == nil {
|
||||
n.Val = any(b).(T)
|
||||
n.Valid = true
|
||||
}
|
||||
case time.Time:
|
||||
if t, err := tryParseDT(s); err == nil && !t.IsZero() {
|
||||
n.Val = any(t).(T)
|
||||
n.Valid = true
|
||||
}
|
||||
case uuid.UUID:
|
||||
if u, err := uuid.Parse(s); err == nil {
|
||||
n.Val = any(u).(T)
|
||||
n.Valid = true
|
||||
}
|
||||
case string:
|
||||
n.Val = any(s).(T)
|
||||
n.Valid = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer.
|
||||
func (n SqlNull[T]) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return any(n.Val), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (n SqlNull[T]) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(n.Val)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 0 || string(b) == "null" || strings.TrimSpace(string(b)) == "" {
|
||||
n.Valid = false
|
||||
n.Val = *new(T)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try direct unmarshal.
|
||||
var val T
|
||||
if err := json.Unmarshal(b, &val); err == nil {
|
||||
n.Val = val
|
||||
n.Valid = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback: unmarshal as string and parse.
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err == nil {
|
||||
return n.FromString(s)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val)
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (n SqlNull[T]) String() string {
|
||||
if !n.Valid {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%v", n.Val)
|
||||
}
|
||||
|
||||
// Int64 converts to int64 or 0 if invalid.
|
||||
func (n SqlNull[T]) Int64() int64 {
|
||||
if !n.Valid {
|
||||
return 0
|
||||
}
|
||||
v := reflect.ValueOf(any(n.Val))
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return int64(v.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return int64(v.Float())
|
||||
case reflect.String:
|
||||
i, _ := strconv.ParseInt(v.String(), 10, 64)
|
||||
return i
|
||||
case reflect.Bool:
|
||||
if v.Bool() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Float64 converts to float64 or 0.0 if invalid.
|
||||
func (n SqlNull[T]) Float64() float64 {
|
||||
if !n.Valid {
|
||||
return 0.0
|
||||
}
|
||||
v := reflect.ValueOf(any(n.Val))
|
||||
switch v.Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return float64(v.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return float64(v.Uint())
|
||||
case reflect.String:
|
||||
f, _ := strconv.ParseFloat(v.String(), 64)
|
||||
return f
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Bool converts to bool or false if invalid.
|
||||
func (n SqlNull[T]) Bool() bool {
|
||||
if !n.Valid {
|
||||
return false
|
||||
}
|
||||
v := reflect.ValueOf(any(n.Val))
|
||||
if v.Kind() == reflect.Bool {
|
||||
return v.Bool()
|
||||
}
|
||||
s := strings.ToLower(strings.TrimSpace(fmt.Sprint(n.Val)))
|
||||
return s == "true" || s == "t" || s == "1" || s == "yes" || s == "on"
|
||||
}
|
||||
|
||||
// Time converts to time.Time or zero if invalid.
|
||||
func (n SqlNull[T]) Time() time.Time {
|
||||
if !n.Valid {
|
||||
return time.Time{}
|
||||
}
|
||||
if t, ok := any(n.Val).(time.Time); ok {
|
||||
return t
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// UUID converts to uuid.UUID or Nil if invalid.
|
||||
func (n SqlNull[T]) UUID() uuid.UUID {
|
||||
if !n.Valid {
|
||||
return uuid.Nil
|
||||
}
|
||||
if u, ok := any(n.Val).(uuid.UUID); ok {
|
||||
return u
|
||||
}
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
// Type aliases for common types.
|
||||
type (
|
||||
SqlInt16 = SqlNull[int16]
|
||||
SqlInt32 = SqlNull[int32]
|
||||
SqlInt64 = SqlNull[int64]
|
||||
SqlFloat64 = SqlNull[float64]
|
||||
SqlBool = SqlNull[bool]
|
||||
SqlString = SqlNull[string]
|
||||
SqlUUID = SqlNull[uuid.UUID]
|
||||
)
|
||||
|
||||
// SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS).
|
||||
type SqlTimeStamp struct{ SqlNull[time.Time] }
|
||||
|
||||
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
|
||||
if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf(`"%s"`, t.Val.Format("2006-01-02T15:04:05"))), nil
|
||||
}
|
||||
|
||||
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
|
||||
if err := t.SqlNull.UnmarshalJSON(b); err != nil {
|
||||
return err
|
||||
}
|
||||
if t.Valid && (t.Val.IsZero() || t.Val.Format("2006-01-02T15:04:05") == "0001-01-01T00:00:00") {
|
||||
t.Valid = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t SqlTimeStamp) Value() (driver.Value, error) {
|
||||
if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||
return nil, nil
|
||||
}
|
||||
return t.Val.Format("2006-01-02T15:04:05"), nil
|
||||
}
|
||||
|
||||
func SqlTimeStampNow() SqlTimeStamp {
|
||||
return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
|
||||
}
|
||||
|
||||
// SqlDate - Date only (YYYY-MM-DD).
|
||||
type SqlDate struct{ SqlNull[time.Time] }
|
||||
|
||||
func (d SqlDate) MarshalJSON() ([]byte, error) {
|
||||
if !d.Valid || d.Val.IsZero() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
s := d.Val.Format("2006-01-02")
|
||||
if strings.HasPrefix(s, "0001-01-01") {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf(`"%s"`, s)), nil
|
||||
}
|
||||
|
||||
func (d *SqlDate) UnmarshalJSON(b []byte) error {
|
||||
if err := d.SqlNull.UnmarshalJSON(b); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.Valid && d.Val.Format("2006-01-02") <= "0001-01-01" {
|
||||
d.Valid = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d SqlDate) Value() (driver.Value, error) {
|
||||
if !d.Valid || d.Val.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
s := d.Val.Format("2006-01-02")
|
||||
if s <= "0001-01-01" {
|
||||
return nil, nil
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d SqlDate) String() string {
|
||||
if !d.Valid {
|
||||
return ""
|
||||
}
|
||||
s := d.Val.Format("2006-01-02")
|
||||
if strings.HasPrefix(s, "0001-01-01") || strings.HasPrefix(s, "1800-12-31") {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func SqlDateNow() SqlDate {
|
||||
return SqlDate{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
|
||||
}
|
||||
|
||||
// SqlTime - Time only (HH:MM:SS).
|
||||
type SqlTime struct{ SqlNull[time.Time] }
|
||||
|
||||
func (t SqlTime) MarshalJSON() ([]byte, error) {
|
||||
if !t.Valid || t.Val.IsZero() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
s := t.Val.Format("15:04:05")
|
||||
if s == "00:00:00" {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf(`"%s"`, s)), nil
|
||||
}
|
||||
|
||||
func (t *SqlTime) UnmarshalJSON(b []byte) error {
|
||||
if err := t.SqlNull.UnmarshalJSON(b); err != nil {
|
||||
return err
|
||||
}
|
||||
if t.Valid && t.Val.Format("15:04:05") == "00:00:00" {
|
||||
t.Valid = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t SqlTime) Value() (driver.Value, error) {
|
||||
if !t.Valid || t.Val.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
return t.Val.Format("15:04:05"), nil
|
||||
}
|
||||
|
||||
func (t SqlTime) String() string {
|
||||
if !t.Valid {
|
||||
return ""
|
||||
}
|
||||
return t.Val.Format("15:04:05")
|
||||
}
|
||||
|
||||
func SqlTimeNow() SqlTime {
|
||||
return SqlTime{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
|
||||
}
|
||||
|
||||
// SqlJSONB - Nullable JSONB as []byte.
|
||||
type SqlJSONB []byte
|
||||
|
||||
// Scan implements sql.Scanner.
|
||||
func (n *SqlJSONB) Scan(value any) error {
|
||||
if value == nil {
|
||||
*n = nil
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
*n = []byte(v)
|
||||
case []byte:
|
||||
*n = v
|
||||
default:
|
||||
dat, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value to JSON: %v", err)
|
||||
}
|
||||
*n = dat
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer.
|
||||
func (n SqlJSONB) Value() (driver.Value, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var js any
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return string(n), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||
if len(n) == 0 {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
var obj any
|
||||
if err := json.Unmarshal(n, &obj); err != nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
||||
s := strings.TrimSpace(string(b))
|
||||
if s == "null" || s == "" || (!strings.HasPrefix(s, "{") && !strings.HasPrefix(s, "[")) {
|
||||
*n = nil
|
||||
return nil
|
||||
}
|
||||
*n = b
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n SqlJSONB) AsMap() (map[string]any, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
js := make(map[string]any)
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func (n SqlJSONB) AsSlice() ([]any, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
js := make([]any, 0)
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
// TryIfInt64 tries to parse any value to int64 with default.
|
||||
func TryIfInt64(v any, def int64) int64 {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
i, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return i
|
||||
case int:
|
||||
return int64(val)
|
||||
case int8:
|
||||
return int64(val)
|
||||
case int16:
|
||||
return int64(val)
|
||||
case int32:
|
||||
return int64(val)
|
||||
case int64:
|
||||
return val
|
||||
case uint:
|
||||
return int64(val)
|
||||
case uint8:
|
||||
return int64(val)
|
||||
case uint16:
|
||||
return int64(val)
|
||||
case uint32:
|
||||
return int64(val)
|
||||
case uint64:
|
||||
return int64(val)
|
||||
case float32:
|
||||
return int64(val)
|
||||
case float64:
|
||||
return int64(val)
|
||||
case []byte:
|
||||
i, err := strconv.ParseInt(string(val), 10, 64)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return i
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
|
||||
// Constructor helpers - clean and fast value creation
|
||||
func Null[T any](v T, valid bool) SqlNull[T] {
|
||||
return SqlNull[T]{Val: v, Valid: valid}
|
||||
}
|
||||
|
||||
func NewSql[T any](value any) SqlNull[T] {
|
||||
n := SqlNull[T]{}
|
||||
|
||||
if value == nil {
|
||||
return n
|
||||
}
|
||||
|
||||
// Fast path: exact match
|
||||
if v, ok := value.(T); ok {
|
||||
n.Val = v
|
||||
n.Valid = true
|
||||
return n
|
||||
}
|
||||
|
||||
// Try from another SqlNull
|
||||
if sn, ok := value.(SqlNull[T]); ok {
|
||||
return sn
|
||||
}
|
||||
|
||||
// Convert via string
|
||||
_ = n.FromString(fmt.Sprintf("%v", value))
|
||||
return n
|
||||
}
|
||||
|
||||
func NewSqlInt16(v int16) SqlInt16 {
|
||||
return SqlInt16{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlInt32(v int32) SqlInt32 {
|
||||
return SqlInt32{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlInt64(v int64) SqlInt64 {
|
||||
return SqlInt64{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlFloat64(v float64) SqlFloat64 {
|
||||
return SqlFloat64{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlBool(v bool) SqlBool {
|
||||
return SqlBool{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlString(v string) SqlString {
|
||||
return SqlString{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlUUID(v uuid.UUID) SqlUUID {
|
||||
return SqlUUID{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlTimeStamp(v time.Time) SqlTimeStamp {
|
||||
return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
|
||||
}
|
||||
|
||||
func NewSqlDate(v time.Time) SqlDate {
|
||||
return SqlDate{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
|
||||
}
|
||||
|
||||
func NewSqlTime(v time.Time) SqlTime {
|
||||
return SqlTime{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
|
||||
}
|
||||
566
pkg/spectypes/sql_types_test.go
Normal file
566
pkg/spectypes/sql_types_test.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package spectypes
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestNewSqlInt16 tests NewSqlInt16 type
|
||||
func TestNewSqlInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, Null(int16(42), true)},
|
||||
{"int32", int32(100), NewSqlInt16(100)},
|
||||
{"int64", int64(200), NewSqlInt16(200)},
|
||||
{"string", "123", NewSqlInt16(123)},
|
||||
{"nil", nil, Null(int16(0), false)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt16
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSqlInt16_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", Null(int16(0), false), nil},
|
||||
{"positive", NewSqlInt16(42), int16(42)},
|
||||
{"negative", NewSqlInt16(-10), int16(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSqlInt16_JSON(t *testing.T) {
|
||||
n := NewSqlInt16(42)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := "42"
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var n2 SqlInt16
|
||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if n2.Int64() != 123 {
|
||||
t.Errorf("expected 123, got %d", n2.Int64())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSqlInt64 tests NewSqlInt64 type
|
||||
func TestNewSqlInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt64
|
||||
}{
|
||||
{"int", 42, NewSqlInt64(42)},
|
||||
{"int32", int32(100), NewSqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), NewSqlInt64(100)},
|
||||
{"uint64", uint64(200), NewSqlInt64(200)},
|
||||
{"nil", nil, SqlInt64{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlFloat64 tests SqlFloat64 type
|
||||
func TestSqlFloat64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected float64
|
||||
valid bool
|
||||
}{
|
||||
{"float64", float64(3.14), 3.14, true},
|
||||
{"float32", float32(2.5), 2.5, true},
|
||||
{"int", 42, 42.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlFloat64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||
}
|
||||
if tt.valid && n.Float64() != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTimeStamp tests SqlTimeStamp type
|
||||
func TestSqlTimeStamp(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string RFC3339", now.Format(time.RFC3339)},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string datetime", "2024-01-15T10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ts SqlTimeStamp
|
||||
if err := ts.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if ts.Time().IsZero() {
|
||||
t.Error("expected non-zero time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
ts := NewSqlTimeStamp(now)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15T10:30:45"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var ts2 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if ts2.Time().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
|
||||
}
|
||||
|
||||
// Test null
|
||||
var ts3 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlDate tests SqlDate type
|
||||
func TestSqlDate(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string UK format", "15/01/2024"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var d SqlDate
|
||||
if err := d.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if d.String() == "0" {
|
||||
t.Error("expected non-zero date")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlDate_JSON(t *testing.T) {
|
||||
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(date)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var d2 SqlDate
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTime tests SqlTime type
|
||||
func TestSqlTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"time.Time", now, now.Format("15:04:05")},
|
||||
{"string time", "10:30:45", "10:30:45"},
|
||||
{"string short time", "10:30", "10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tm SqlTime
|
||||
if err := tm.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tm.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, tm.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlJSONB tests SqlJSONB type
|
||||
func TestSqlJSONB_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
|
||||
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
|
||||
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
|
||||
{"nil", nil, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var j SqlJSONB
|
||||
if err := j.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && j == nil {
|
||||
return // nil case
|
||||
}
|
||||
if string(j) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, string(j))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
|
||||
{"empty", SqlJSONB{}, "", false},
|
||||
{"nil", nil, "", false},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && val == nil {
|
||||
return // nil case
|
||||
}
|
||||
if val.(string) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_JSON(t *testing.T) {
|
||||
// Marshal
|
||||
j := SqlJSONB(`{"name":"test","count":42}`)
|
||||
data, err := json.Marshal(j)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("Unmarshal result failed: %v", err)
|
||||
}
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("expected name=test, got %v", result["name"])
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var j2 SqlJSONB
|
||||
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if string(j2) != `{"key":"value"}` {
|
||||
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
|
||||
}
|
||||
|
||||
// Test null
|
||||
var j3 SqlJSONB
|
||||
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
|
||||
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m, err := tt.input.AsMap()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsMap failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if m != nil {
|
||||
t.Errorf("expected nil, got %v", m)
|
||||
}
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
t.Error("expected non-nil map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsSlice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
|
||||
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := tt.input.AsSlice()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsSlice failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
t.Error("expected non-nil slice")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlUUID tests SqlUUID type
|
||||
func TestSqlUUID_Scan(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
testUUIDStr := testUUID.String()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"string UUID", testUUIDStr, testUUIDStr, true},
|
||||
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
|
||||
{"nil", nil, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var u SqlUUID
|
||||
if err := u.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if u.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||
}
|
||||
if tt.valid && u.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_Value(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := NewSqlUUID(testUUID)
|
||||
|
||||
val, err := u.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
// Test invalid UUID
|
||||
u2 := SqlUUID{Valid: false}
|
||||
val2, err := u2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val2 != nil {
|
||||
t.Errorf("expected nil, got %v", val2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_JSON(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := NewSqlUUID(testUUID)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"` + testUUID.String() + `"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var u2 SqlUUID
|
||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if u2.String() != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
|
||||
}
|
||||
|
||||
// Test null
|
||||
var u3 SqlUUID
|
||||
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
if u3.Valid {
|
||||
t.Error("expected invalid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTryIfInt64 tests the TryIfInt64 helper function
|
||||
func TestTryIfInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
def int64
|
||||
expected int64
|
||||
}{
|
||||
{"string valid", "123", 0, 123},
|
||||
{"string invalid", "abc", 99, 99},
|
||||
{"int", 42, 0, 42},
|
||||
{"int32", int32(100), 0, 100},
|
||||
{"int64", int64(200), 0, 200},
|
||||
{"uint32", uint32(50), 0, 50},
|
||||
{"uint64", uint64(75), 0, 75},
|
||||
{"float32", float32(3.14), 0, 3},
|
||||
{"float64", float64(2.71), 0, 2},
|
||||
{"bytes", []byte("456"), 0, 456},
|
||||
{"unknown type", struct{}{}, 999, 999},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TryIfInt64(tt.input, tt.def)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user