mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-31 00:34:25 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07b09e2025 | ||
|
|
3d5334002d | ||
|
|
640582d508 | ||
|
|
b0b3ae662b |
@@ -319,6 +319,10 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||||
|
if b.hasModel {
|
||||||
|
// If model is set, do not override table name
|
||||||
|
return b
|
||||||
|
}
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|||||||
774
pkg/common/sql_types.go
Normal file
774
pkg/common/sql_types.go
Normal file
@@ -0,0 +1,774 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tryParseDT(str string) (time.Time, error) {
|
||||||
|
var lasterror error
|
||||||
|
tryFormats := []string{time.RFC3339,
|
||||||
|
"2006-01-02T15:04:05.000-0700",
|
||||||
|
"2006-01-02T15:04:05.000",
|
||||||
|
"06-01-02T15:04:05.000",
|
||||||
|
"2006-01-02T15:04:05",
|
||||||
|
"2006-01-02 15:04:05",
|
||||||
|
"02/01/2006",
|
||||||
|
"02-01-2006",
|
||||||
|
"2006-01-02",
|
||||||
|
"15:04:05.000",
|
||||||
|
"15:04:05",
|
||||||
|
"15:04"}
|
||||||
|
|
||||||
|
for _, f := range tryFormats {
|
||||||
|
tx, err := time.Parse(f, str)
|
||||||
|
if err == nil {
|
||||||
|
return tx, nil
|
||||||
|
} else {
|
||||||
|
lasterror = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now(), lasterror
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToJSONDT(dt time.Time) string {
|
||||||
|
return dt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlInt16 - A Int16 that supports SQL string
|
||||||
|
type SqlInt16 int16
|
||||||
|
|
||||||
|
// Scan -
|
||||||
|
func (n *SqlInt16) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*n = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
*n = SqlInt16(v)
|
||||||
|
case int32:
|
||||||
|
*n = SqlInt16(v)
|
||||||
|
case int64:
|
||||||
|
*n = SqlInt16(v)
|
||||||
|
default:
|
||||||
|
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||||
|
*n = SqlInt16(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value -
|
||||||
|
func (n SqlInt16) Value() (driver.Value, error) {
|
||||||
|
if n == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return int64(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of ZNullInt32
|
||||||
|
func (n SqlInt16) String() string {
|
||||||
|
tmstr := fmt.Sprintf("%d", n)
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||||
|
func (n *SqlInt16) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
|
||||||
|
n64, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
*n = SqlInt16(n64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlInt16) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf("%d", n)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlInt32 - A int32 that supports SQL string
|
||||||
|
type SqlInt32 int32
|
||||||
|
|
||||||
|
// Scan -
|
||||||
|
func (n *SqlInt32) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*n = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
*n = SqlInt32(v)
|
||||||
|
case int32:
|
||||||
|
*n = SqlInt32(v)
|
||||||
|
case int64:
|
||||||
|
*n = SqlInt32(v)
|
||||||
|
default:
|
||||||
|
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||||
|
*n = SqlInt32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value -
|
||||||
|
func (n SqlInt32) Value() (driver.Value, error) {
|
||||||
|
if n == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return int64(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of ZNullInt32
|
||||||
|
func (n SqlInt32) String() string {
|
||||||
|
tmstr := fmt.Sprintf("%d", n)
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||||
|
func (n *SqlInt32) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
|
||||||
|
n64, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
*n = SqlInt32(n64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlInt32) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf("%d", n)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlInt64 - A int64 that supports SQL string
|
||||||
|
type SqlInt64 int64
|
||||||
|
|
||||||
|
// Scan -
|
||||||
|
func (n *SqlInt64) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*n = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
*n = SqlInt64(v)
|
||||||
|
case int32:
|
||||||
|
*n = SqlInt64(v)
|
||||||
|
case uint32:
|
||||||
|
*n = SqlInt64(v)
|
||||||
|
case int64:
|
||||||
|
*n = SqlInt64(v)
|
||||||
|
case uint64:
|
||||||
|
*n = SqlInt64(v)
|
||||||
|
default:
|
||||||
|
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||||
|
*n = SqlInt64(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value -
|
||||||
|
func (n SqlInt64) Value() (driver.Value, error) {
|
||||||
|
if n == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return int64(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of ZNullInt32
|
||||||
|
func (n SqlInt64) String() string {
|
||||||
|
tmstr := fmt.Sprintf("%d", n)
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||||
|
func (n *SqlInt64) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
|
||||||
|
n64, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
*n = SqlInt64(n64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlInt64) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf("%d", n)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlTimeStamp - Implementation of SqlTimeStamp with some interfaces.
|
||||||
|
type SqlTimeStamp time.Time
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
|
||||||
|
if time.Time(t).IsZero() {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
if time.Time(t).Before(time.Date(0001, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
|
||||||
|
if tmstr == "0001-01-01T00:00:00" {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Override JSON format of time
|
||||||
|
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if b == nil {
|
||||||
|
t = &SqlTimeStamp{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
if s == "null" || s == "" || s == "0" ||
|
||||||
|
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
|
||||||
|
t = &SqlTimeStamp{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := tryParseDT(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*t = SqlTimeStamp(tx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value - SQL Value of custom date
|
||||||
|
func (t SqlTimeStamp) Value() (driver.Value, error) {
|
||||||
|
if t.GetTime().IsZero() || t.GetTime().Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
|
||||||
|
if tmstr <= "0001-01-01" || tmstr == "" {
|
||||||
|
empty := time.Time{}
|
||||||
|
return empty, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmstr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan - Scan custom date from sql
|
||||||
|
func (t *SqlTimeStamp) Scan(value interface{}) error {
|
||||||
|
tm, ok := value.(time.Time)
|
||||||
|
if ok {
|
||||||
|
*t = SqlTimeStamp(tm)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
str, ok := value.(string)
|
||||||
|
if ok {
|
||||||
|
tx, err := tryParseDT(str)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*t = SqlTimeStamp(tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of time
|
||||||
|
func (t SqlTimeStamp) String() string {
|
||||||
|
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTime - Returns Time
|
||||||
|
func (t SqlTimeStamp) GetTime() time.Time {
|
||||||
|
return time.Time(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTime - Returns Time
|
||||||
|
func (t *SqlTimeStamp) SetTime(pTime time.Time) {
|
||||||
|
*t = SqlTimeStamp(pTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format - Formats the time
|
||||||
|
func (t SqlTimeStamp) Format(layout string) string {
|
||||||
|
return fmt.Sprintf("%s", time.Time(t).Format(layout))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SqlTimeStampNow() SqlTimeStamp {
|
||||||
|
tx := time.Now()
|
||||||
|
|
||||||
|
return SqlTimeStamp(tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlFloat64 - SQL Int
|
||||||
|
type SqlFloat64 sql.NullFloat64
|
||||||
|
|
||||||
|
// Scan -
|
||||||
|
func (n *SqlFloat64) Scan(value interface{}) error {
|
||||||
|
newval := sql.NullFloat64{Float64: 0, Valid: false}
|
||||||
|
if value == nil {
|
||||||
|
newval.Valid = false
|
||||||
|
*n = SqlFloat64(newval)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case float64:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case float32:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case int64:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case int32:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case uint16:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case uint64:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
case uint32:
|
||||||
|
newval.Float64 = float64(v)
|
||||||
|
newval.Valid = true
|
||||||
|
default:
|
||||||
|
i, err := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||||
|
newval.Float64 = float64(i)
|
||||||
|
if err == nil {
|
||||||
|
newval.Valid = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*n = SqlFloat64(newval)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value -
|
||||||
|
func (n SqlFloat64) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return float64(n.Float64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String -
|
||||||
|
func (n SqlFloat64) String() string {
|
||||||
|
if !n.Valid {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
tmstr := fmt.Sprintf("%f", n.Float64)
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON -
|
||||||
|
func (n *SqlFloat64) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
invalid := (s == "null" || s == "" || len(s) < 2) || (strings.Contains(s, "{") || strings.Contains(s, "["))
|
||||||
|
if invalid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nval, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*n = SqlFloat64(sql.NullFloat64{Valid: true, Float64: float64(nval)})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlFloat64) MarshalJSON() ([]byte, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("%f", n.Float64)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlDate - Implementation of SqlTime with some interfaces.
|
||||||
|
type SqlDate time.Time
|
||||||
|
|
||||||
|
// UnmarshalJSON - Override JSON format of time
|
||||||
|
func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
if s == "null" || s == "" || s == "0" ||
|
||||||
|
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
|
||||||
|
s == "0001-01-01" {
|
||||||
|
t = &SqlDate{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := tryParseDT(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*t = SqlDate(tx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (t SqlDate) MarshalJSON() ([]byte, error) {
|
||||||
|
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
||||||
|
if strings.HasPrefix(tmstr, "0001-01-01") {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value - SQL Value of custom date
|
||||||
|
func (t SqlDate) Value() (driver.Value, error) {
|
||||||
|
var s time.Time
|
||||||
|
tmstr := time.Time(t).Format("2006-01-02")
|
||||||
|
if strings.HasPrefix(tmstr, "0001-01-01") || tmstr <= "0001-01-01" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
s = time.Time(t)
|
||||||
|
|
||||||
|
return s.Format("2006-01-02"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan - Scan custom date from sql
|
||||||
|
func (t *SqlDate) Scan(value interface{}) error {
|
||||||
|
tm, ok := value.(time.Time)
|
||||||
|
if ok {
|
||||||
|
*t = SqlDate(tm)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
str, ok := value.(string)
|
||||||
|
if ok {
|
||||||
|
tx, err := tryParseDT(str)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*t = SqlDate(tx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int64 - Override date format in unix epoch
|
||||||
|
func (t SqlDate) Int64() int64 {
|
||||||
|
return time.Time(t).Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of time
|
||||||
|
func (t SqlDate) String() string {
|
||||||
|
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
||||||
|
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
func SqlDateNow() SqlDate {
|
||||||
|
tx := time.Now()
|
||||||
|
return SqlDate(tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ////////////////////// SqlTime /////////////////////////
|
||||||
|
// SqlTime - Implementation of SqlTime with some interfaces.
|
||||||
|
type SqlTime time.Time
|
||||||
|
|
||||||
|
// Int64 - Override Time format in unix epoch
|
||||||
|
func (t SqlTime) Int64() int64 {
|
||||||
|
return time.Time(t).Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of time
|
||||||
|
func (t SqlTime) String() string {
|
||||||
|
return time.Time(t).Format("15:04:05")
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Override JSON format of time
|
||||||
|
func (t *SqlTime) UnmarshalJSON(b []byte) error {
|
||||||
|
var err error
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
if s == "null" || s == "" || s == "0" ||
|
||||||
|
s == "0001-01-01T00:00:00" || s == "00:00:00" {
|
||||||
|
*t = SqlTime{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tx := time.Time{}
|
||||||
|
tx, err = tryParseDT(s)
|
||||||
|
*t = SqlTime(tx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format - Format Function
|
||||||
|
func (t SqlTime) Format(form string) string {
|
||||||
|
tmstr := time.Time(t).Format(form)
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan - Scan custom date from sql
|
||||||
|
func (t *SqlTime) Scan(value interface{}) error {
|
||||||
|
tm, ok := value.(time.Time)
|
||||||
|
if ok {
|
||||||
|
*t = SqlTime(tm)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
str, ok := value.(string)
|
||||||
|
if ok {
|
||||||
|
tx, err := tryParseDT(str)
|
||||||
|
*t = SqlTime(tx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value - SQL Value of custom date
|
||||||
|
func (t SqlTime) Value() (driver.Value, error) {
|
||||||
|
|
||||||
|
s := time.Time(t)
|
||||||
|
st := s.Format("15:04:05")
|
||||||
|
|
||||||
|
return st, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (t SqlTime) MarshalJSON() ([]byte, error) {
|
||||||
|
tmstr := time.Time(t).Format("15:04:05")
|
||||||
|
if tmstr == "0001-01-01T00:00:00" || tmstr == "00:00:00" {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SqlTimeNow() SqlTime {
|
||||||
|
tx := time.Now()
|
||||||
|
return SqlTime(tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlJSONB - Nullable JSONB String
|
||||||
|
type SqlJSONB []byte
|
||||||
|
|
||||||
|
// Scan - Implements sql.Scanner for reading JSONB from database
|
||||||
|
func (n *SqlJSONB) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*n = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
*n = SqlJSONB([]byte(v))
|
||||||
|
case []byte:
|
||||||
|
*n = SqlJSONB(v)
|
||||||
|
default:
|
||||||
|
// For other types, marshal to JSON
|
||||||
|
dat, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal value to JSON: %v", err)
|
||||||
|
}
|
||||||
|
*n = SqlJSONB(dat)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value - Implements driver.Valuer for writing JSONB to database
|
||||||
|
func (n SqlJSONB) Value() (driver.Value, error) {
|
||||||
|
if len(n) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that it's valid JSON before returning
|
||||||
|
var js interface{}
|
||||||
|
if err := json.Unmarshal(n, &js); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return as string for PostgreSQL JSONB/JSON columns
|
||||||
|
return string(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n SqlJSONB) AsMap() (map[string]any, error) {
|
||||||
|
if len(n) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// Validate that it's valid JSON before returning
|
||||||
|
js := make(map[string]any)
|
||||||
|
if err := json.Unmarshal(n, &js); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
return js, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n SqlJSONB) AsSlice() ([]any, error) {
|
||||||
|
if len(n) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// Validate that it's valid JSON before returning
|
||||||
|
js := make([]any, 0)
|
||||||
|
if err := json.Unmarshal(n, &js); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
return js, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Override JSON
|
||||||
|
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
|
||||||
|
if invalid {
|
||||||
|
s = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*n = []byte(s)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||||
|
if n == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
var obj interface{}
|
||||||
|
err := json.Unmarshal(n, &obj)
|
||||||
|
if err != nil {
|
||||||
|
//fmt.Printf("Invalid JSON %v", err)
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dat, err := json.MarshalIndent(obj, " ", " ")
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("failed to convert to JSON: %v", err)
|
||||||
|
// }
|
||||||
|
dat := n
|
||||||
|
|
||||||
|
return dat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlUUID - Nullable UUID String
|
||||||
|
type SqlUUID sql.NullString
|
||||||
|
|
||||||
|
// Scan -
|
||||||
|
func (n *SqlUUID) Scan(value interface{}) error {
|
||||||
|
str := sql.NullString{String: "", Valid: false}
|
||||||
|
if value == nil {
|
||||||
|
*n = SqlUUID(str)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
uuid, err := uuid.Parse(v)
|
||||||
|
if err == nil {
|
||||||
|
str.String = uuid.String()
|
||||||
|
str.Valid = true
|
||||||
|
*n = SqlUUID(str)
|
||||||
|
}
|
||||||
|
case []uint8:
|
||||||
|
uuid, err := uuid.ParseBytes(v)
|
||||||
|
if err == nil {
|
||||||
|
str.String = uuid.String()
|
||||||
|
str.Valid = true
|
||||||
|
*n = SqlUUID(str)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
uuid, err := uuid.Parse(fmt.Sprintf("%v", v))
|
||||||
|
if err == nil {
|
||||||
|
str.String = uuid.String()
|
||||||
|
str.Valid = true
|
||||||
|
*n = SqlUUID(str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value -
|
||||||
|
func (n SqlUUID) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return n.String, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Override JSON
|
||||||
|
func (n *SqlUUID) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
invalid := (s == "null" || s == "" || len(s) < 30)
|
||||||
|
if invalid {
|
||||||
|
s = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlUUID) MarshalJSON() ([]byte, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("\"%s\"", n.String)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryIfInt64 - Wrapper function to quickly try and cast text to int
|
||||||
|
func TryIfInt64(v any, def int64) int64 {
|
||||||
|
str := ""
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
str = val
|
||||||
|
case int:
|
||||||
|
return int64(val)
|
||||||
|
case int32:
|
||||||
|
return int64(val)
|
||||||
|
case int64:
|
||||||
|
return val
|
||||||
|
case uint32:
|
||||||
|
return int64(val)
|
||||||
|
case uint64:
|
||||||
|
return int64(val)
|
||||||
|
case float32:
|
||||||
|
return int64(val)
|
||||||
|
case float64:
|
||||||
|
return int64(val)
|
||||||
|
case []byte:
|
||||||
|
str = string(val)
|
||||||
|
default:
|
||||||
|
str = fmt.Sprintf("%d", def)
|
||||||
|
}
|
||||||
|
val, err := strconv.ParseInt(str, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
566
pkg/common/sql_types_test.go
Normal file
566
pkg/common/sql_types_test.go
Normal file
@@ -0,0 +1,566 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSqlInt16 tests SqlInt16 type
|
||||||
|
func TestSqlInt16(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected SqlInt16
|
||||||
|
}{
|
||||||
|
{"int", 42, SqlInt16(42)},
|
||||||
|
{"int32", int32(100), SqlInt16(100)},
|
||||||
|
{"int64", int64(200), SqlInt16(200)},
|
||||||
|
{"string", "123", SqlInt16(123)},
|
||||||
|
{"nil", nil, SqlInt16(0)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var n SqlInt16
|
||||||
|
if err := n.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != tt.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlInt16_Value(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input SqlInt16
|
||||||
|
expected driver.Value
|
||||||
|
}{
|
||||||
|
{"zero", SqlInt16(0), nil},
|
||||||
|
{"positive", SqlInt16(42), int64(42)},
|
||||||
|
{"negative", SqlInt16(-10), int64(-10)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
val, err := tt.input.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value failed: %v", err)
|
||||||
|
}
|
||||||
|
if val != tt.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tt.expected, val)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlInt16_JSON(t *testing.T) {
|
||||||
|
n := SqlInt16(42)
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
expected := "42"
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var n2 SqlInt16
|
||||||
|
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if n2 != 123 {
|
||||||
|
t.Errorf("expected 123, got %d", n2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlInt64 tests SqlInt64 type
|
||||||
|
func TestSqlInt64(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected SqlInt64
|
||||||
|
}{
|
||||||
|
{"int", 42, SqlInt64(42)},
|
||||||
|
{"int32", int32(100), SqlInt64(100)},
|
||||||
|
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
||||||
|
{"uint32", uint32(100), SqlInt64(100)},
|
||||||
|
{"uint64", uint64(200), SqlInt64(200)},
|
||||||
|
{"nil", nil, SqlInt64(0)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var n SqlInt64
|
||||||
|
if err := n.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != tt.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlFloat64 tests SqlFloat64 type
|
||||||
|
func TestSqlFloat64(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected float64
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{"float64", float64(3.14), 3.14, true},
|
||||||
|
{"float32", float32(2.5), 2.5, true},
|
||||||
|
{"int", 42, 42.0, true},
|
||||||
|
{"int64", int64(100), 100.0, true},
|
||||||
|
{"nil", nil, 0, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var n SqlFloat64
|
||||||
|
if err := n.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if n.Valid != tt.valid {
|
||||||
|
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||||
|
}
|
||||||
|
if tt.valid && n.Float64 != tt.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlTimeStamp tests SqlTimeStamp type
|
||||||
|
func TestSqlTimeStamp(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
}{
|
||||||
|
{"time.Time", now},
|
||||||
|
{"string RFC3339", now.Format(time.RFC3339)},
|
||||||
|
{"string date", "2024-01-15"},
|
||||||
|
{"string datetime", "2024-01-15T10:30:00"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var ts SqlTimeStamp
|
||||||
|
if err := ts.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if ts.GetTime().IsZero() {
|
||||||
|
t.Error("expected non-zero time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||||
|
ts := SqlTimeStamp(now)
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
expected := `"2024-01-15T10:30:45"`
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var ts2 SqlTimeStamp
|
||||||
|
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if ts2.GetTime().Year() != 2024 {
|
||||||
|
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test null
|
||||||
|
var ts3 SqlTimeStamp
|
||||||
|
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
|
||||||
|
t.Fatalf("Unmarshal null failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlDate tests SqlDate type
|
||||||
|
func TestSqlDate(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
}{
|
||||||
|
{"time.Time", now},
|
||||||
|
{"string date", "2024-01-15"},
|
||||||
|
{"string UK format", "15/01/2024"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var d SqlDate
|
||||||
|
if err := d.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if d.String() == "0" {
|
||||||
|
t.Error("expected non-zero date")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlDate_JSON(t *testing.T) {
|
||||||
|
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(date)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
expected := `"2024-01-15"`
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var d2 SqlDate
|
||||||
|
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlTime tests SqlTime type
|
||||||
|
func TestSqlTime(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"time.Time", now, now.Format("15:04:05")},
|
||||||
|
{"string time", "10:30:45", "10:30:45"},
|
||||||
|
{"string short time", "10:30", "10:30:00"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var tm SqlTime
|
||||||
|
if err := tm.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if tm.String() != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, tm.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlJSONB tests SqlJSONB type
|
||||||
|
func TestSqlJSONB_Scan(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
|
||||||
|
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
|
||||||
|
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
|
||||||
|
{"nil", nil, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var j SqlJSONB
|
||||||
|
if err := j.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.expected == "" && j == nil {
|
||||||
|
return // nil case
|
||||||
|
}
|
||||||
|
if string(j) != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, string(j))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlJSONB_Value(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input SqlJSONB
|
||||||
|
expected string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
|
||||||
|
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
|
||||||
|
{"empty", SqlJSONB{}, "", false},
|
||||||
|
{"nil", nil, "", false},
|
||||||
|
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
val, err := tt.input.Value()
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.expected == "" && val == nil {
|
||||||
|
return // nil case
|
||||||
|
}
|
||||||
|
if val.(string) != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, val)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlJSONB_JSON(t *testing.T) {
|
||||||
|
// Marshal
|
||||||
|
j := SqlJSONB(`{"name":"test","count":42}`)
|
||||||
|
data, err := json.Marshal(j)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
t.Fatalf("Unmarshal result failed: %v", err)
|
||||||
|
}
|
||||||
|
if result["name"] != "test" {
|
||||||
|
t.Errorf("expected name=test, got %v", result["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var j2 SqlJSONB
|
||||||
|
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(j2) != `{"key":"value"}` {
|
||||||
|
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test null
|
||||||
|
var j3 SqlJSONB
|
||||||
|
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
|
||||||
|
t.Fatalf("Unmarshal null failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlJSONB_AsMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input SqlJSONB
|
||||||
|
wantErr bool
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
|
||||||
|
{"empty", SqlJSONB{}, false, true},
|
||||||
|
{"nil", nil, false, true},
|
||||||
|
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
|
||||||
|
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m, err := tt.input.AsMap()
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AsMap failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.wantNil {
|
||||||
|
if m != nil {
|
||||||
|
t.Errorf("expected nil, got %v", m)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
t.Error("expected non-nil map")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlJSONB_AsSlice(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input SqlJSONB
|
||||||
|
wantErr bool
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
|
||||||
|
{"empty", SqlJSONB{}, false, true},
|
||||||
|
{"nil", nil, false, true},
|
||||||
|
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
|
||||||
|
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s, err := tt.input.AsSlice()
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AsSlice failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.wantNil {
|
||||||
|
if s != nil {
|
||||||
|
t.Errorf("expected nil, got %v", s)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s == nil {
|
||||||
|
t.Error("expected non-nil slice")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlUUID tests SqlUUID type
|
||||||
|
func TestSqlUUID_Scan(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
testUUIDStr := testUUID.String()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{"string UUID", testUUIDStr, testUUIDStr, true},
|
||||||
|
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
|
||||||
|
{"nil", nil, "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var u SqlUUID
|
||||||
|
if err := u.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if u.Valid != tt.valid {
|
||||||
|
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||||
|
}
|
||||||
|
if tt.valid && u.String != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlUUID_Value(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||||
|
|
||||||
|
val, err := u.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value failed: %v", err)
|
||||||
|
}
|
||||||
|
if val != testUUID.String() {
|
||||||
|
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid UUID
|
||||||
|
u2 := SqlUUID{Valid: false}
|
||||||
|
val2, err := u2.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value failed: %v", err)
|
||||||
|
}
|
||||||
|
if val2 != nil {
|
||||||
|
t.Errorf("expected nil, got %v", val2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlUUID_JSON(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
expected := `"` + testUUID.String() + `"`
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var u2 SqlUUID
|
||||||
|
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if u2.String != testUUID.String() {
|
||||||
|
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test null
|
||||||
|
var u3 SqlUUID
|
||||||
|
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
|
||||||
|
t.Fatalf("Unmarshal null failed: %v", err)
|
||||||
|
}
|
||||||
|
if u3.Valid {
|
||||||
|
t.Error("expected invalid UUID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTryIfInt64 tests the TryIfInt64 helper function
|
||||||
|
func TestTryIfInt64(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
def int64
|
||||||
|
expected int64
|
||||||
|
}{
|
||||||
|
{"string valid", "123", 0, 123},
|
||||||
|
{"string invalid", "abc", 99, 99},
|
||||||
|
{"int", 42, 0, 42},
|
||||||
|
{"int32", int32(100), 0, 100},
|
||||||
|
{"int64", int64(200), 0, 200},
|
||||||
|
{"uint32", uint32(50), 0, 50},
|
||||||
|
{"uint64", uint64(75), 0, 75},
|
||||||
|
{"float32", float32(3.14), 0, 3},
|
||||||
|
{"float64", float64(2.71), 0, 2},
|
||||||
|
{"bytes", []byte("456"), 0, 456},
|
||||||
|
{"unknown type", struct{}{}, 999, 999},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := TryIfInt64(tt.input, tt.def)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("expected %d, got %d", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -92,9 +92,27 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
|||||||
return strings.ToLower(field.Name)
|
return strings.ToLower(field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||||
|
// Examples:
|
||||||
|
// - "columna->>'val'" returns "columna"
|
||||||
|
// - "columna->'key'" returns "columna"
|
||||||
|
// - "columna" returns "columna"
|
||||||
|
// - "table.columna->>'val'" returns "table.columna"
|
||||||
|
func extractSourceColumn(colName string) string {
|
||||||
|
// Check for PostgreSQL JSON operators: -> and ->>
|
||||||
|
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateColumn validates a single column name
|
// ValidateColumn validates a single column name
|
||||||
// Returns nil if valid, error if invalid
|
// Returns nil if valid, error if invalid
|
||||||
// Columns prefixed with "cql" (case insensitive) are always valid
|
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||||
|
// Handles PostgreSQL JSON operators (-> and ->>)
|
||||||
func (v *ColumnValidator) ValidateColumn(column string) error {
|
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||||
// Allow empty columns
|
// Allow empty columns
|
||||||
if column == "" {
|
if column == "" {
|
||||||
@@ -106,8 +124,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract source column name (remove JSON operators like ->> or ->)
|
||||||
|
sourceColumn := extractSourceColumn(column)
|
||||||
|
|
||||||
// Check if column exists in model
|
// Check if column exists in model
|
||||||
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
|
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
||||||
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
124
pkg/common/validation_json_test.go
Normal file
124
pkg/common/validation_json_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractSourceColumn(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple column name",
|
||||||
|
input: "columna",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with ->> operator",
|
||||||
|
input: "columna->>'val'",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with -> operator",
|
||||||
|
input: "columna->'key'",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with table prefix and ->> operator",
|
||||||
|
input: "table.columna->>'val'",
|
||||||
|
expected: "table.columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with table prefix and -> operator",
|
||||||
|
input: "table.columna->'key'",
|
||||||
|
expected: "table.columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex JSON path with ->>",
|
||||||
|
input: "data->>'nested'->>'value'",
|
||||||
|
expected: "data",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with spaces before operator",
|
||||||
|
input: "columna ->>'val'",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := extractSourceColumn(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumnWithJSONOperators(t *testing.T) {
|
||||||
|
// Create a test model
|
||||||
|
type TestModel struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Data string `json:"data"` // JSON column
|
||||||
|
Metadata string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
validator := NewColumnValidator(TestModel{})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
column string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple valid column",
|
||||||
|
column: "name",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid column with ->> operator",
|
||||||
|
column: "data->>'field'",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid column with -> operator",
|
||||||
|
column: "metadata->'key'",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column",
|
||||||
|
column: "invalid_column",
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column with ->> operator",
|
||||||
|
column: "invalid_column->>'field'",
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cql prefixed column (always valid)",
|
||||||
|
column: "cql_computed",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty column",
|
||||||
|
column: "",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumn(tc.column)
|
||||||
|
if tc.shouldErr && err == nil {
|
||||||
|
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
|
||||||
|
}
|
||||||
|
if !tc.shouldErr && err != nil {
|
||||||
|
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -663,7 +663,14 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create insert query
|
// Create insert query
|
||||||
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
|
query := tx.NewInsert().Model(modelValue)
|
||||||
|
|
||||||
|
// Only set Table() if the model doesn't provide a table name via TableNameProvider
|
||||||
|
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
|
query = query.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
query = query.Returning("*")
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
itemHookCtx := &HookContext{
|
itemHookCtx := &HookContext{
|
||||||
@@ -1640,10 +1647,13 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
response := common.Response{
|
response := data
|
||||||
Success: true,
|
if response == nil {
|
||||||
Data: data,
|
response = common.Response{
|
||||||
Metadata: metadata,
|
Success: true,
|
||||||
|
Data: data,
|
||||||
|
Metadata: metadata,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(response); err != nil {
|
||||||
|
|||||||
@@ -480,12 +480,32 @@ func (h *Handler) parseCommaSeparated(value string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||||
|
// Examples:
|
||||||
|
// - "columna->>'val'" returns "columna"
|
||||||
|
// - "columna->'key'" returns "columna"
|
||||||
|
// - "columna" returns "columna"
|
||||||
|
// - "table.columna->>'val'" returns "table.columna"
|
||||||
|
func extractSourceColumn(colName string) string {
|
||||||
|
// Check for PostgreSQL JSON operators: -> and ->>
|
||||||
|
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
|
||||||
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||||
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||||
if model == nil {
|
if model == nil {
|
||||||
return reflect.Invalid
|
return reflect.Invalid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract the source column name (remove JSON operators like ->> or ->)
|
||||||
|
sourceColName := extractSourceColumn(colName)
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
// Dereference pointer if needed
|
// Dereference pointer if needed
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Ptr {
|
||||||
@@ -506,19 +526,19 @@ func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) refl
|
|||||||
if jsonTag != "" {
|
if jsonTag != "" {
|
||||||
// Parse JSON tag (format: "name,omitempty")
|
// Parse JSON tag (format: "name,omitempty")
|
||||||
parts := strings.Split(jsonTag, ",")
|
parts := strings.Split(jsonTag, ",")
|
||||||
if parts[0] == colName {
|
if parts[0] == sourceColName {
|
||||||
return field.Type.Kind()
|
return field.Type.Kind()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check field name (case-insensitive)
|
// Check field name (case-insensitive)
|
||||||
if strings.EqualFold(field.Name, colName) {
|
if strings.EqualFold(field.Name, sourceColName) {
|
||||||
return field.Type.Kind()
|
return field.Type.Kind()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check snake_case conversion
|
// Check snake_case conversion
|
||||||
snakeCaseName := toSnakeCase(field.Name)
|
snakeCaseName := toSnakeCase(field.Name)
|
||||||
if snakeCaseName == colName {
|
if snakeCaseName == sourceColName {
|
||||||
return field.Type.Kind()
|
return field.Type.Kind()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user