From c52afe282579f0a4f87afc15a95576d6570b2a68 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 9 Dec 2025 13:14:22 +0200 Subject: [PATCH] Updated sql types --- pkg/common/sql_types.go | 962 ++++++++++++++--------------------- pkg/common/sql_types_test.go | 72 +-- 2 files changed, 411 insertions(+), 623 deletions(-) diff --git a/pkg/common/sql_types.go b/pkg/common/sql_types.go index d07e966..cef9a0d 100644 --- a/pkg/common/sql_types.go +++ b/pkg/common/sql_types.go @@ -1,3 +1,4 @@ +// Package common provides nullable SQL types with automatic casting and conversion methods. package common import ( @@ -5,6 +6,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "reflect" "strconv" "strings" "time" @@ -12,9 +14,11 @@ import ( "github.com/google/uuid" ) +// tryParseDT attempts to parse a string into a time.Time using various formats. func tryParseDT(str string) (time.Time, error) { var lasterror error - tryFormats := []string{time.RFC3339, + tryFormats := []string{ + time.RFC3339, "2006-01-02T15:04:05.000-0700", "2006-01-02T15:04:05.000", "06-01-02T15:04:05.000", @@ -25,599 +29,431 @@ func tryParseDT(str string) (time.Time, error) { "2006-01-02", "15:04:05.000", "15:04:05", - "15:04"} - + "15:04", + } for _, f := range tryFormats { tx, err := time.Parse(f, str) if err == nil { return tx, nil - } else { - lasterror = err } + lasterror = err } - - return time.Now(), lasterror + return time.Time{}, lasterror // Return zero time on failure } +// ToJSONDT formats a time.Time to RFC3339 string. func ToJSONDT(dt time.Time) string { return dt.Format(time.RFC3339) } -// SqlInt16 - A Int16 that supports SQL string -type SqlInt16 int16 +// SqlNull is a generic nullable type that behaves like sql.NullXXX with auto-casting. +type SqlNull[T any] struct { + Val T + Valid bool +} -// Scan - -func (n *SqlInt16) Scan(value interface{}) error { +// Scan implements sql.Scanner. +func (n *SqlNull[T]) Scan(value any) error { if value == nil { - *n = 0 + n.Valid = false + n.Val = *new(T) return nil } + + // Try standard sql.Null[T] first. + var sqlNull sql.Null[T] + if err := sqlNull.Scan(value); err == nil { + n.Val = sqlNull.V + n.Valid = sqlNull.Valid + return nil + } + + // Fallback: parse from string/bytes. switch v := value.(type) { - case int: - *n = SqlInt16(v) - case int32: - *n = SqlInt16(v) - case int64: - *n = SqlInt16(v) + case string: + return n.fromString(v) + case []byte: + return n.fromString(string(v)) default: - i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64) - *n = SqlInt16(i) + return n.fromString(fmt.Sprintf("%v", value)) } - - return nil } -// Value - -func (n SqlInt16) Value() (driver.Value, error) { - if n == 0 { - return nil, nil - } - return int64(n), nil -} - -// String - Override String format of ZNullInt32 -func (n SqlInt16) String() string { - tmstr := fmt.Sprintf("%d", n) - return tmstr -} - -// UnmarshalJSON - Overre JidSON format of ZNullInt32 -func (n *SqlInt16) UnmarshalJSON(b []byte) error { - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - - n64, err := strconv.ParseInt(s, 10, 64) - if err == nil { - *n = SqlInt16(n64) - } - - return nil -} - -// MarshalJSON - Override JSON format of time -func (n SqlInt16) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("%d", n)), nil -} - -// SqlInt32 - A int32 that supports SQL string -type SqlInt32 int32 - -// Scan - -func (n *SqlInt32) Scan(value interface{}) error { - if value == nil { - *n = 0 - return nil - } - switch v := value.(type) { - case int: - *n = SqlInt32(v) - case int32: - *n = SqlInt32(v) - case int64: - *n = SqlInt32(v) - default: - i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64) - *n = SqlInt32(i) - } - - return nil -} - -// Value - -func (n SqlInt32) Value() (driver.Value, error) { - if n == 0 { - return nil, nil - } - return int64(n), nil -} - -// String - Override String format of ZNullInt32 -func (n SqlInt32) String() string { - tmstr := fmt.Sprintf("%d", n) - return tmstr -} - -// UnmarshalJSON - Overre JidSON format of ZNullInt32 -func (n *SqlInt32) UnmarshalJSON(b []byte) error { - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - - n64, err := strconv.ParseInt(s, 10, 64) - if err == nil { - *n = SqlInt32(n64) - } - - return nil -} - -// MarshalJSON - Override JSON format of time -func (n SqlInt32) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("%d", n)), nil -} - -// SqlInt64 - A int64 that supports SQL string -type SqlInt64 int64 - -// Scan - -func (n *SqlInt64) Scan(value interface{}) error { - if value == nil { - *n = 0 - return nil - } - switch v := value.(type) { - case int: - *n = SqlInt64(v) - case int32: - *n = SqlInt64(v) - case uint32: - *n = SqlInt64(v) - case int64: - *n = SqlInt64(v) - case uint64: - *n = SqlInt64(v) - default: - i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64) - *n = SqlInt64(i) - } - - return nil -} - -// Value - -func (n SqlInt64) Value() (driver.Value, error) { - if n == 0 { - return nil, nil - } - return int64(n), nil -} - -// String - Override String format of ZNullInt32 -func (n SqlInt64) String() string { - tmstr := fmt.Sprintf("%d", n) - return tmstr -} - -// UnmarshalJSON - Overre JidSON format of ZNullInt32 -func (n *SqlInt64) UnmarshalJSON(b []byte) error { - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - - n64, err := strconv.ParseInt(s, 10, 64) - if err == nil { - *n = SqlInt64(n64) - } - - return nil -} - -// MarshalJSON - Override JSON format of time -func (n SqlInt64) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("%d", n)), nil -} - -// SqlTimeStamp - Implementation of SqlTimeStamp with some interfaces. -type SqlTimeStamp time.Time - -// MarshalJSON - Override JSON format of time -func (t SqlTimeStamp) MarshalJSON() ([]byte, error) { - if time.Time(t).IsZero() { - return []byte("null"), nil - } - if time.Time(t).Before(time.Date(0001, 1, 1, 0, 0, 0, 0, time.UTC)) { - return []byte("null"), nil - } - tmstr := time.Time(t).Format("2006-01-02T15:04:05") - if tmstr == "0001-01-01T00:00:00" { - return []byte("null"), nil - } - return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil -} - -// UnmarshalJSON - Override JSON format of time -func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error { - var err error - - if b == nil { - - return nil - } - s := strings.Trim(strings.Trim(string(b), " "), "\"") - if s == "null" || s == "" || s == "0" || - s == "0001-01-01T00:00:00" || s == "0001-01-01" { +func (n *SqlNull[T]) fromString(s string) error { + s = strings.TrimSpace(s) + n.Valid = false + n.Val = *new(T) + if s == "" || strings.EqualFold(s, "null") { return nil } - tx, err := tryParseDT(s) - if err != nil { - return err - } - - *t = SqlTimeStamp(tx) - return err -} - -// Value - SQL Value of custom date -func (t SqlTimeStamp) Value() (driver.Value, error) { - if t.GetTime().IsZero() || t.GetTime().Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) { - return nil, nil - } - tmstr := time.Time(t).Format("2006-01-02T15:04:05") - if tmstr <= "0001-01-01" || tmstr == "" { - empty := time.Time{} - return empty, nil - } - - return tmstr, nil -} - -// Scan - Scan custom date from sql -func (t *SqlTimeStamp) Scan(value interface{}) error { - tm, ok := value.(time.Time) - if ok { - *t = SqlTimeStamp(tm) - return nil - } - - str, ok := value.(string) - if ok { - tx, err := tryParseDT(str) - if err != nil { - return err + var zero T + switch any(zero).(type) { + case int, int8, int16, int32, int64: + if i, err := strconv.ParseInt(s, 10, 64); err == nil { + n.Val = any(int64(i)).(T) // Cast to T (e.g., int16) + n.Valid = true } - *t = SqlTimeStamp(tx) - } - - return nil -} - -// String - Override String format of time -func (t SqlTimeStamp) String() string { - return time.Time(t).Format("2006-01-02T15:04:05") -} - -// GetTime - Returns Time -func (t SqlTimeStamp) GetTime() time.Time { - return time.Time(t) -} - -// SetTime - Returns Time -func (t *SqlTimeStamp) SetTime(pTime time.Time) { - *t = SqlTimeStamp(pTime) -} - -// Format - Formats the time -func (t SqlTimeStamp) Format(layout string) string { - return time.Time(t).Format(layout) -} - -func SqlTimeStampNow() SqlTimeStamp { - tx := time.Now() - - return SqlTimeStamp(tx) -} - -// SqlFloat64 - SQL Int -type SqlFloat64 sql.NullFloat64 - -// Scan - -func (n *SqlFloat64) Scan(value interface{}) error { - newval := sql.NullFloat64{Float64: 0, Valid: false} - if value == nil { - newval.Valid = false - *n = SqlFloat64(newval) - return nil - } - switch v := value.(type) { - case int: - newval.Float64 = float64(v) - newval.Valid = true - case float64: - newval.Float64 = float64(v) - newval.Valid = true - case float32: - newval.Float64 = float64(v) - newval.Valid = true - case int64: - newval.Float64 = float64(v) - newval.Valid = true - case int32: - newval.Float64 = float64(v) - newval.Valid = true - case uint16: - newval.Float64 = float64(v) - newval.Valid = true - case uint64: - newval.Float64 = float64(v) - newval.Valid = true - case uint32: - newval.Float64 = float64(v) - newval.Valid = true - default: - i, err := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64) - newval.Float64 = float64(i) - if err == nil { - newval.Valid = false + case uint, uint8, uint16, uint32, uint64: + if u, err := strconv.ParseUint(s, 10, 64); err == nil { + n.Val = any(u).(T) + n.Valid = true } + case float32, float64: + if f, err := strconv.ParseFloat(s, 64); err == nil { + n.Val = any(f).(T) + n.Valid = true + } + case bool: + if b, err := strconv.ParseBool(s); err == nil { + n.Val = any(b).(T) + n.Valid = true + } + case time.Time: + if t, err := tryParseDT(s); err == nil && !t.IsZero() { + n.Val = any(t).(T) + n.Valid = true + } + case uuid.UUID: + if u, err := uuid.Parse(s); err == nil { + n.Val = any(u).(T) + n.Valid = true + } + case string: + n.Val = any(s).(T) + n.Valid = true } - *n = SqlFloat64(newval) return nil } -// Value - -func (n SqlFloat64) Value() (driver.Value, error) { +// Value implements driver.Valuer. +func (n SqlNull[T]) Value() (driver.Value, error) { if !n.Valid { return nil, nil } - return float64(n.Float64), nil + return any(n.Val), nil } -// String - -func (n SqlFloat64) String() string { +// MarshalJSON implements json.Marshaler. +func (n SqlNull[T]) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return json.Marshal(n.Val) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *SqlNull[T]) UnmarshalJSON(b []byte) error { + if len(b) == 0 || string(b) == "null" || strings.TrimSpace(string(b)) == "" { + n.Valid = false + n.Val = *new(T) + return nil + } + + // Try direct unmarshal. + var val T + if err := json.Unmarshal(b, &val); err == nil { + n.Val = val + n.Valid = true + return nil + } + + // Fallback: unmarshal as string and parse. + var s string + if err := json.Unmarshal(b, &s); err == nil { + return n.fromString(s) + } + + return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val) +} + +// String implements fmt.Stringer. +func (n SqlNull[T]) String() string { if !n.Valid { return "" } - tmstr := fmt.Sprintf("%f", n.Float64) - return tmstr + return fmt.Sprintf("%v", n.Val) } -// UnmarshalJSON - -func (n *SqlFloat64) UnmarshalJSON(b []byte) error { - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - invalid := (s == "null" || s == "" || len(s) < 2) || (strings.Contains(s, "{") || strings.Contains(s, "[")) - if invalid { - return nil +// Int64 converts to int64 or 0 if invalid. +func (n SqlNull[T]) Int64() int64 { + if !n.Valid { + return 0 } + v := reflect.ValueOf(any(n.Val)) + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(v.Uint()) + case reflect.Float32, reflect.Float64: + return int64(v.Float()) + case reflect.String: + i, _ := strconv.ParseInt(v.String(), 10, 64) + return i + case reflect.Bool: + if v.Bool() { + return 1 + } + return 0 + } + return 0 +} - nval, err := strconv.ParseInt(s, 10, 64) - if err != nil { +// Float64 converts to float64 or 0.0 if invalid. +func (n SqlNull[T]) Float64() float64 { + if !n.Valid { + return 0.0 + } + v := reflect.ValueOf(any(n.Val)) + switch v.Kind() { + case reflect.Float32, reflect.Float64: + return v.Float() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(v.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(v.Uint()) + case reflect.String: + f, _ := strconv.ParseFloat(v.String(), 64) + return f + } + return 0.0 +} + +// Bool converts to bool or false if invalid. +func (n SqlNull[T]) Bool() bool { + if !n.Valid { + return false + } + v := reflect.ValueOf(any(n.Val)) + if v.Kind() == reflect.Bool { + return v.Bool() + } + s := strings.ToLower(strings.TrimSpace(fmt.Sprint(n.Val))) + return s == "true" || s == "t" || s == "1" || s == "yes" || s == "on" +} + +// Time converts to time.Time or zero if invalid. +func (n SqlNull[T]) Time() time.Time { + if !n.Valid { + return time.Time{} + } + if t, ok := any(n.Val).(time.Time); ok { + return t + } + return time.Time{} +} + +// UUID converts to uuid.UUID or Nil if invalid. +func (n SqlNull[T]) UUID() uuid.UUID { + if !n.Valid { + return uuid.Nil + } + if u, ok := any(n.Val).(uuid.UUID); ok { + return u + } + return uuid.Nil +} + +// Type aliases for common types. +type ( + SqlInt16 = SqlNull[int16] + SqlInt32 = SqlNull[int32] + SqlInt64 = SqlNull[int64] + SqlFloat64 = SqlNull[float64] + SqlBool = SqlNull[bool] + SqlString = SqlNull[string] + SqlUUID = SqlNull[uuid.UUID] +) + +// SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS). +type SqlTimeStamp struct{ SqlNull[time.Time] } + +func (t SqlTimeStamp) MarshalJSON() ([]byte, error) { + if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) { + return []byte("null"), nil + } + return []byte(fmt.Sprintf(`"%s"`, t.Val.Format("2006-01-02T15:04:05"))), nil +} + +func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error { + if err := t.SqlNull.UnmarshalJSON(b); err != nil { return err } - - *n = SqlFloat64(sql.NullFloat64{Valid: true, Float64: float64(nval)}) - + if t.Valid && (t.Val.IsZero() || t.Val.Format("2006-01-02T15:04:05") == "0001-01-01T00:00:00") { + t.Valid = false + } return nil } -// MarshalJSON - Override JSON format of time -func (n SqlFloat64) MarshalJSON() ([]byte, error) { - if !n.Valid { - return []byte("null"), nil - } - return []byte(fmt.Sprintf("%f", n.Float64)), nil -} - -// SqlDate - Implementation of SqlTime with some interfaces. -type SqlDate time.Time - -// UnmarshalJSON - Override JSON format of time -func (t *SqlDate) UnmarshalJSON(b []byte) error { - var err error - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - if s == "null" || s == "" || s == "0" || - strings.HasPrefix(s, "0001-01-01T00:00:00") || - s == "0001-01-01" { - return nil - } - - tx, err := tryParseDT(s) - if err != nil { - return err - } - *t = SqlDate(tx) - return err -} - -// MarshalJSON - Override JSON format of time -func (t SqlDate) MarshalJSON() ([]byte, error) { - tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339 - if strings.HasPrefix(tmstr, "0001-01-01") { - return []byte("null"), nil - } - return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil -} - -// Value - SQL Value of custom date -func (t SqlDate) Value() (driver.Value, error) { - var s time.Time - tmstr := time.Time(t).Format("2006-01-02") - if strings.HasPrefix(tmstr, "0001-01-01") || tmstr <= "0001-01-01" { +func (t SqlTimeStamp) Value() (driver.Value, error) { + if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) { return nil, nil } - s = time.Time(t) - - return s.Format("2006-01-02"), nil + return t.Val.Format("2006-01-02T15:04:05"), nil } -// Scan - Scan custom date from sql -func (t *SqlDate) Scan(value interface{}) error { - tm, ok := value.(time.Time) - if ok { - *t = SqlDate(tm) - return nil +func SqlTimeStampNow() SqlTimeStamp { + return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}} +} + +// SqlDate - Date only (YYYY-MM-DD). +type SqlDate struct{ SqlNull[time.Time] } + +func (d SqlDate) MarshalJSON() ([]byte, error) { + if !d.Valid || d.Val.IsZero() { + return []byte("null"), nil } + s := d.Val.Format("2006-01-02") + if strings.HasPrefix(s, "0001-01-01") { + return []byte("null"), nil + } + return []byte(fmt.Sprintf(`"%s"`, s)), nil +} - str, ok := value.(string) - if ok { - tx, err := tryParseDT(str) - if err != nil { - return err - } - - *t = SqlDate(tx) +func (d *SqlDate) UnmarshalJSON(b []byte) error { + if err := d.SqlNull.UnmarshalJSON(b); err != nil { return err } - + if d.Valid && d.Val.Format("2006-01-02") <= "0001-01-01" { + d.Valid = false + } return nil } -// Int64 - Override date format in unix epoch -func (t SqlDate) Int64() int64 { - return time.Time(t).Unix() +func (d SqlDate) Value() (driver.Value, error) { + if !d.Valid || d.Val.IsZero() { + return nil, nil + } + s := d.Val.Format("2006-01-02") + if s <= "0001-01-01" { + return nil, nil + } + return s, nil } -// String - Override String format of time -func (t SqlDate) String() string { - tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339 - if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") { - return "0" +func (d SqlDate) String() string { + if !d.Valid { + return "" } - return tmstr + s := d.Val.Format("2006-01-02") + if strings.HasPrefix(s, "0001-01-01") || strings.HasPrefix(s, "1800-12-31") { + return "" + } + return s } func SqlDateNow() SqlDate { - tx := time.Now() - return SqlDate(tx) + return SqlDate{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}} } -// ////////////////////// SqlTime ///////////////////////// -// SqlTime - Implementation of SqlTime with some interfaces. -type SqlTime time.Time +// SqlTime - Time only (HH:MM:SS). +type SqlTime struct{ SqlNull[time.Time] } -// Int64 - Override Time format in unix epoch -func (t SqlTime) Int64() int64 { - return time.Time(t).Unix() +func (t SqlTime) MarshalJSON() ([]byte, error) { + if !t.Valid || t.Val.IsZero() { + return []byte("null"), nil + } + s := t.Val.Format("15:04:05") + if s == "00:00:00" { + return []byte("null"), nil + } + return []byte(fmt.Sprintf(`"%s"`, s)), nil } -// String - Override String format of time -func (t SqlTime) String() string { - return time.Time(t).Format("15:04:05") -} - -// UnmarshalJSON - Override JSON format of time func (t *SqlTime) UnmarshalJSON(b []byte) error { - var err error - s := strings.Trim(strings.Trim(string(b), " "), "\"") - if s == "null" || s == "" || s == "0" || - s == "0001-01-01T00:00:00" || s == "00:00:00" { - *t = SqlTime{} - return nil - } - - tx, err := tryParseDT(s) - *t = SqlTime(tx) - - return err -} - -// Format - Format Function -func (t SqlTime) Format(form string) string { - tmstr := time.Time(t).Format(form) - return tmstr -} - -// Scan - Scan custom date from sql -func (t *SqlTime) Scan(value interface{}) error { - tm, ok := value.(time.Time) - if ok { - *t = SqlTime(tm) - return nil - } - - str, ok := value.(string) - if ok { - tx, err := tryParseDT(str) - *t = SqlTime(tx) + if err := t.SqlNull.UnmarshalJSON(b); err != nil { return err } - + if t.Valid && t.Val.Format("15:04:05") == "00:00:00" { + t.Valid = false + } return nil } -// Value - SQL Value of custom date func (t SqlTime) Value() (driver.Value, error) { - - s := time.Time(t) - st := s.Format("15:04:05") - - return st, nil + if !t.Valid || t.Val.IsZero() { + return nil, nil + } + return t.Val.Format("15:04:05"), nil } -// MarshalJSON - Override JSON format of time -func (t SqlTime) MarshalJSON() ([]byte, error) { - tmstr := time.Time(t).Format("15:04:05") - if tmstr == "0001-01-01T00:00:00" || tmstr == "00:00:00" { - return []byte("null"), nil +func (t SqlTime) String() string { + if !t.Valid { + return "" } - return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil + return t.Val.Format("15:04:05") } func SqlTimeNow() SqlTime { - tx := time.Now() - return SqlTime(tx) + return SqlTime{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}} } -// SqlJSONB - Nullable JSONB String +// SqlJSONB - Nullable JSONB as []byte. type SqlJSONB []byte -// Scan - Implements sql.Scanner for reading JSONB from database -func (n *SqlJSONB) Scan(value interface{}) error { +// Scan implements sql.Scanner. +func (n *SqlJSONB) Scan(value any) error { if value == nil { *n = nil return nil } - switch v := value.(type) { case string: - *n = SqlJSONB([]byte(v)) + *n = []byte(v) case []byte: - *n = SqlJSONB(v) + *n = v default: - // For other types, marshal to JSON dat, err := json.Marshal(value) if err != nil { return fmt.Errorf("failed to marshal value to JSON: %v", err) } - *n = SqlJSONB(dat) + *n = dat } - return nil } -// Value - Implements driver.Valuer for writing JSONB to database +// Value implements driver.Valuer. func (n SqlJSONB) Value() (driver.Value, error) { if len(n) == 0 { return nil, nil } - - // Validate that it's valid JSON before returning - var js interface{} + var js any if err := json.Unmarshal(n, &js); err != nil { return nil, fmt.Errorf("invalid JSON: %v", err) } - - // Return as string for PostgreSQL JSONB/JSON columns return string(n), nil } +// MarshalJSON implements json.Marshaler. +func (n SqlJSONB) MarshalJSON() ([]byte, error) { + if len(n) == 0 { + return []byte("null"), nil + } + var obj any + if err := json.Unmarshal(n, &obj); err != nil { + return []byte("null"), nil + } + return n, nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *SqlJSONB) UnmarshalJSON(b []byte) error { + s := strings.TrimSpace(string(b)) + if s == "null" || s == "" || (!strings.HasPrefix(s, "{") && !strings.HasPrefix(s, "[")) { + *n = nil + return nil + } + *n = b + return nil +} + func (n SqlJSONB) AsMap() (map[string]any, error) { if len(n) == 0 { return nil, nil } - // Validate that it's valid JSON before returning js := make(map[string]any) if err := json.Unmarshal(n, &js); err != nil { return nil, fmt.Errorf("invalid JSON: %v", err) @@ -629,7 +465,6 @@ func (n SqlJSONB) AsSlice() ([]any, error) { if len(n) == 0 { return nil, nil } - // Validate that it's valid JSON before returning js := make([]any, 0) if err := json.Unmarshal(n, &js); err != nil { return nil, fmt.Errorf("invalid JSON: %v", err) @@ -637,119 +472,31 @@ func (n SqlJSONB) AsSlice() ([]any, error) { return js, nil } -// UnmarshalJSON - Override JSON -func (n *SqlJSONB) UnmarshalJSON(b []byte) error { - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "[")) - if invalid { - return nil - } - - *n = []byte(s) - - return nil -} - -// MarshalJSON - Override JSON format of time -func (n SqlJSONB) MarshalJSON() ([]byte, error) { - if n == nil { - return []byte("null"), nil - } - var obj interface{} - err := json.Unmarshal(n, &obj) - if err != nil { - // fmt.Printf("Invalid JSON %v", err) - return []byte("null"), nil - } - - // dat, err := json.MarshalIndent(obj, " ", " ") - // if err != nil { - // return nil, fmt.Errorf("failed to convert to JSON: %v", err) - // } - dat := n - - return dat, nil -} - -// SqlUUID - Nullable UUID String -type SqlUUID sql.NullString - -// Scan - -func (n *SqlUUID) Scan(value interface{}) error { - str := sql.NullString{String: "", Valid: false} - if value == nil { - *n = SqlUUID(str) - return nil - } - switch v := value.(type) { - case string: - uuid, err := uuid.Parse(v) - if err == nil { - str.String = uuid.String() - str.Valid = true - *n = SqlUUID(str) - } - case []uint8: - uuid, err := uuid.ParseBytes(v) - if err == nil { - str.String = uuid.String() - str.Valid = true - *n = SqlUUID(str) - } - default: - uuid, err := uuid.Parse(fmt.Sprintf("%v", v)) - if err == nil { - str.String = uuid.String() - str.Valid = true - *n = SqlUUID(str) - } - } - - return nil -} - -// Value - -func (n SqlUUID) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return n.String, nil -} - -// UnmarshalJSON - Override JSON -func (n *SqlUUID) UnmarshalJSON(b []byte) error { - - s := strings.Trim(strings.Trim(string(b), " "), "\"") - invalid := (s == "null" || s == "" || len(s) < 30) - if invalid { - return nil - } - *n = SqlUUID(sql.NullString{String: s, Valid: !invalid}) - - return nil -} - -// MarshalJSON - Override JSON format of time -func (n SqlUUID) MarshalJSON() ([]byte, error) { - if !n.Valid { - return []byte("null"), nil - } - return []byte(fmt.Sprintf("\"%s\"", n.String)), nil -} - -// TryIfInt64 - Wrapper function to quickly try and cast text to int +// TryIfInt64 tries to parse any value to int64 with default. func TryIfInt64(v any, def int64) int64 { - str := "" switch val := v.(type) { case string: - str = val + i, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return def + } + return i case int: return int64(val) + case int8: + return int64(val) + case int16: + return int64(val) case int32: return int64(val) case int64: return val + case uint: + return int64(val) + case uint8: + return int64(val) + case uint16: + return int64(val) case uint32: return int64(val) case uint64: @@ -759,13 +506,54 @@ func TryIfInt64(v any, def int64) int64 { case float64: return int64(val) case []byte: - str = string(val) + i, err := strconv.ParseInt(string(val), 10, 64) + if err != nil { + return def + } + return i default: - str = fmt.Sprintf("%d", def) - } - val, err := strconv.ParseInt(str, 10, 64) - if err != nil { return def } - return val +} + +// Constructor helpers - clean and fast value creation + +func NewSqlInt16(v int16) SqlInt16 { + return SqlInt16{Val: v, Valid: true} +} + +func NewSqlInt32(v int32) SqlInt32 { + return SqlInt32{Val: v, Valid: true} +} + +func NewSqlInt64(v int64) SqlInt64 { + return SqlInt64{Val: v, Valid: true} +} + +func NewSqlFloat64(v float64) SqlFloat64 { + return SqlFloat64{Val: v, Valid: true} +} + +func NewSqlBool(v bool) SqlBool { + return SqlBool{Val: v, Valid: true} +} + +func NewSqlString(v string) SqlString { + return SqlString{Val: v, Valid: true} +} + +func NewSqlUUID(v uuid.UUID) SqlUUID { + return SqlUUID{Val: v, Valid: true} +} + +func NewSqlTimeStamp(v time.Time) SqlTimeStamp { + return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}} +} + +func NewSqlDate(v time.Time) SqlDate { + return SqlDate{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}} +} + +func NewSqlTime(v time.Time) SqlTime { + return SqlTime{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}} } diff --git a/pkg/common/sql_types_test.go b/pkg/common/sql_types_test.go index f3ec642..0117e52 100644 --- a/pkg/common/sql_types_test.go +++ b/pkg/common/sql_types_test.go @@ -9,18 +9,18 @@ import ( "github.com/google/uuid" ) -// TestSqlInt16 tests SqlInt16 type -func TestSqlInt16(t *testing.T) { +// TestNewSqlInt16 tests NewSqlInt16 type +func TestNewSqlInt16(t *testing.T) { tests := []struct { name string input interface{} expected SqlInt16 }{ - {"int", 42, SqlInt16(42)}, - {"int32", int32(100), SqlInt16(100)}, - {"int64", int64(200), SqlInt16(200)}, - {"string", "123", SqlInt16(123)}, - {"nil", nil, SqlInt16(0)}, + {"int", 42, NewSqlInt16(42)}, + {"int32", int32(100), NewSqlInt16(100)}, + {"int64", int64(200), NewSqlInt16(200)}, + {"string", "123", NewSqlInt16(123)}, + {"nil", nil, NewSqlInt16(0)}, } for _, tt := range tests { @@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) { } } -func TestSqlInt16_Value(t *testing.T) { +func TestNewSqlInt16_Value(t *testing.T) { tests := []struct { name string input SqlInt16 expected driver.Value }{ - {"zero", SqlInt16(0), nil}, - {"positive", SqlInt16(42), int64(42)}, - {"negative", SqlInt16(-10), int64(-10)}, + {"zero", NewSqlInt16(0), nil}, + {"positive", NewSqlInt16(42), int64(42)}, + {"negative", NewSqlInt16(-10), int64(-10)}, } for _, tt := range tests { @@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) { } } -func TestSqlInt16_JSON(t *testing.T) { - n := SqlInt16(42) +func TestNewSqlInt16_JSON(t *testing.T) { + n := NewSqlInt16(42) // Marshal data, err := json.Marshal(n) @@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) { if err := json.Unmarshal([]byte("123"), &n2); err != nil { t.Fatalf("Unmarshal failed: %v", err) } - if n2 != 123 { - t.Errorf("expected 123, got %d", n2) + if n2.Int64() != 123 { + t.Errorf("expected 123, got %d", n2.Int64()) } } -// TestSqlInt64 tests SqlInt64 type -func TestSqlInt64(t *testing.T) { +// TestNewSqlInt64 tests NewSqlInt64 type +func TestNewSqlInt64(t *testing.T) { tests := []struct { name string input interface{} expected SqlInt64 }{ - {"int", 42, SqlInt64(42)}, - {"int32", int32(100), SqlInt64(100)}, - {"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)}, - {"uint32", uint32(100), SqlInt64(100)}, - {"uint64", uint64(200), SqlInt64(200)}, - {"nil", nil, SqlInt64(0)}, + {"int", 42, NewSqlInt64(42)}, + {"int32", int32(100), NewSqlInt64(100)}, + {"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)}, + {"uint32", uint32(100), NewSqlInt64(100)}, + {"uint64", uint64(200), NewSqlInt64(200)}, + {"nil", nil, SqlInt64{}}, } for _, tt := range tests { @@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) { if n.Valid != tt.valid { t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid) } - if tt.valid && n.Float64 != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, n.Float64) + if tt.valid && n.Float64() != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, n.Float64()) } }) } @@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) { if err := ts.Scan(tt.input); err != nil { t.Fatalf("Scan failed: %v", err) } - if ts.GetTime().IsZero() { + if ts.Time().IsZero() { t.Error("expected non-zero time") } }) @@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) { func TestSqlTimeStamp_JSON(t *testing.T) { now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC) - ts := SqlTimeStamp(now) + ts := NewSqlTimeStamp(now) // Marshal data, err := json.Marshal(ts) @@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) { if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil { t.Fatalf("Unmarshal failed: %v", err) } - if ts2.GetTime().Year() != 2024 { - t.Errorf("expected year 2024, got %d", ts2.GetTime().Year()) + if ts2.Time().Year() != 2024 { + t.Errorf("expected year 2024, got %d", ts2.Time().Year()) } // Test null @@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) { } func TestSqlDate_JSON(t *testing.T) { - date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC)) + date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC)) // Marshal data, err := json.Marshal(date) @@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) { if u.Valid != tt.valid { t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid) } - if tt.valid && u.String != tt.expected { - t.Errorf("expected %s, got %s", tt.expected, u.String) + if tt.valid && u.String() != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, u.String()) } }) } @@ -480,7 +480,7 @@ func TestSqlUUID_Scan(t *testing.T) { func TestSqlUUID_Value(t *testing.T) { testUUID := uuid.New() - u := SqlUUID{String: testUUID.String(), Valid: true} + u := NewSqlUUID(testUUID) val, err := u.Value() if err != nil { @@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) { func TestSqlUUID_JSON(t *testing.T) { testUUID := uuid.New() - u := SqlUUID{String: testUUID.String(), Valid: true} + u := NewSqlUUID(testUUID) // Marshal data, err := json.Marshal(u) @@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) { if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil { t.Fatalf("Unmarshal failed: %v", err) } - if u2.String != testUUID.String() { - t.Errorf("expected %s, got %s", testUUID.String(), u2.String) + if u2.String() != testUUID.String() { + t.Errorf("expected %s, got %s", testUUID.String(), u2.String()) } // Test null