mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 08:14:25 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07b09e2025 | ||
|
|
3d5334002d | ||
|
|
640582d508 |
@@ -319,6 +319,10 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||
if b.hasModel {
|
||||
// If model is set, do not override table name
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Table(table)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func tryParseDT(str string) (time.Time, error) {
|
||||
@@ -671,3 +673,102 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||
|
||||
return dat, nil
|
||||
}
|
||||
|
||||
// SqlUUID - Nullable UUID String
|
||||
type SqlUUID sql.NullString
|
||||
|
||||
// Scan -
|
||||
func (n *SqlUUID) Scan(value interface{}) error {
|
||||
str := sql.NullString{String: "", Valid: false}
|
||||
if value == nil {
|
||||
*n = SqlUUID(str)
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
uuid, err := uuid.Parse(v)
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
case []uint8:
|
||||
uuid, err := uuid.ParseBytes(v)
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
default:
|
||||
uuid, err := uuid.Parse(fmt.Sprintf("%v", v))
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlUUID) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.String, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON
|
||||
func (n *SqlUUID) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 30)
|
||||
if invalid {
|
||||
s = ""
|
||||
return nil
|
||||
}
|
||||
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlUUID) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", n.String)), nil
|
||||
}
|
||||
|
||||
// TryIfInt64 - Wrapper function to quickly try and cast text to int
|
||||
func TryIfInt64(v any, def int64) int64 {
|
||||
str := ""
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
str = val
|
||||
case int:
|
||||
return int64(val)
|
||||
case int32:
|
||||
return int64(val)
|
||||
case int64:
|
||||
return val
|
||||
case uint32:
|
||||
return int64(val)
|
||||
case uint64:
|
||||
return int64(val)
|
||||
case float32:
|
||||
return int64(val)
|
||||
case float64:
|
||||
return int64(val)
|
||||
case []byte:
|
||||
str = string(val)
|
||||
default:
|
||||
str = fmt.Sprintf("%d", def)
|
||||
}
|
||||
val, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
566
pkg/common/sql_types_test.go
Normal file
566
pkg/common/sql_types_test.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestSqlInt16 tests SqlInt16 type
|
||||
func TestSqlInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, SqlInt16(42)},
|
||||
{"int32", int32(100), SqlInt16(100)},
|
||||
{"int64", int64(200), SqlInt16(200)},
|
||||
{"string", "123", SqlInt16(123)},
|
||||
{"nil", nil, SqlInt16(0)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt16
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", SqlInt16(0), nil},
|
||||
{"positive", SqlInt16(42), int64(42)},
|
||||
{"negative", SqlInt16(-10), int64(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_JSON(t *testing.T) {
|
||||
n := SqlInt16(42)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := "42"
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var n2 SqlInt16
|
||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if n2 != 123 {
|
||||
t.Errorf("expected 123, got %d", n2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlInt64 tests SqlInt64 type
|
||||
func TestSqlInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt64
|
||||
}{
|
||||
{"int", 42, SqlInt64(42)},
|
||||
{"int32", int32(100), SqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), SqlInt64(100)},
|
||||
{"uint64", uint64(200), SqlInt64(200)},
|
||||
{"nil", nil, SqlInt64(0)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlFloat64 tests SqlFloat64 type
|
||||
func TestSqlFloat64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected float64
|
||||
valid bool
|
||||
}{
|
||||
{"float64", float64(3.14), 3.14, true},
|
||||
{"float32", float32(2.5), 2.5, true},
|
||||
{"int", 42, 42.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlFloat64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||
}
|
||||
if tt.valid && n.Float64 != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTimeStamp tests SqlTimeStamp type
|
||||
func TestSqlTimeStamp(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string RFC3339", now.Format(time.RFC3339)},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string datetime", "2024-01-15T10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ts SqlTimeStamp
|
||||
if err := ts.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if ts.GetTime().IsZero() {
|
||||
t.Error("expected non-zero time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
ts := SqlTimeStamp(now)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15T10:30:45"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var ts2 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if ts2.GetTime().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
||||
}
|
||||
|
||||
// Test null
|
||||
var ts3 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlDate tests SqlDate type
|
||||
func TestSqlDate(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string UK format", "15/01/2024"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var d SqlDate
|
||||
if err := d.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if d.String() == "0" {
|
||||
t.Error("expected non-zero date")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlDate_JSON(t *testing.T) {
|
||||
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(date)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var d2 SqlDate
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTime tests SqlTime type
|
||||
func TestSqlTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"time.Time", now, now.Format("15:04:05")},
|
||||
{"string time", "10:30:45", "10:30:45"},
|
||||
{"string short time", "10:30", "10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tm SqlTime
|
||||
if err := tm.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tm.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, tm.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlJSONB tests SqlJSONB type
|
||||
func TestSqlJSONB_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
|
||||
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
|
||||
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
|
||||
{"nil", nil, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var j SqlJSONB
|
||||
if err := j.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && j == nil {
|
||||
return // nil case
|
||||
}
|
||||
if string(j) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, string(j))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
|
||||
{"empty", SqlJSONB{}, "", false},
|
||||
{"nil", nil, "", false},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && val == nil {
|
||||
return // nil case
|
||||
}
|
||||
if val.(string) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_JSON(t *testing.T) {
|
||||
// Marshal
|
||||
j := SqlJSONB(`{"name":"test","count":42}`)
|
||||
data, err := json.Marshal(j)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("Unmarshal result failed: %v", err)
|
||||
}
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("expected name=test, got %v", result["name"])
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var j2 SqlJSONB
|
||||
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if string(j2) != `{"key":"value"}` {
|
||||
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
|
||||
}
|
||||
|
||||
// Test null
|
||||
var j3 SqlJSONB
|
||||
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
|
||||
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m, err := tt.input.AsMap()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsMap failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if m != nil {
|
||||
t.Errorf("expected nil, got %v", m)
|
||||
}
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
t.Error("expected non-nil map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsSlice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
|
||||
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := tt.input.AsSlice()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsSlice failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
t.Error("expected non-nil slice")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlUUID tests SqlUUID type
|
||||
func TestSqlUUID_Scan(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
testUUIDStr := testUUID.String()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"string UUID", testUUIDStr, testUUIDStr, true},
|
||||
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
|
||||
{"nil", nil, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var u SqlUUID
|
||||
if err := u.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if u.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||
}
|
||||
if tt.valid && u.String != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_Value(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
|
||||
val, err := u.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
// Test invalid UUID
|
||||
u2 := SqlUUID{Valid: false}
|
||||
val2, err := u2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val2 != nil {
|
||||
t.Errorf("expected nil, got %v", val2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_JSON(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"` + testUUID.String() + `"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var u2 SqlUUID
|
||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if u2.String != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||
}
|
||||
|
||||
// Test null
|
||||
var u3 SqlUUID
|
||||
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
if u3.Valid {
|
||||
t.Error("expected invalid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTryIfInt64 tests the TryIfInt64 helper function
|
||||
func TestTryIfInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
def int64
|
||||
expected int64
|
||||
}{
|
||||
{"string valid", "123", 0, 123},
|
||||
{"string invalid", "abc", 99, 99},
|
||||
{"int", 42, 0, 42},
|
||||
{"int32", int32(100), 0, 100},
|
||||
{"int64", int64(200), 0, 200},
|
||||
{"uint32", uint32(50), 0, 50},
|
||||
{"uint64", uint64(75), 0, 75},
|
||||
{"float32", float32(3.14), 0, 3},
|
||||
{"float64", float64(2.71), 0, 2},
|
||||
{"bytes", []byte("456"), 0, 456},
|
||||
{"unknown type", struct{}{}, 999, 999},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TryIfInt64(tt.input, tt.def)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -92,9 +92,27 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
||||
return strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||
// Examples:
|
||||
// - "columna->>'val'" returns "columna"
|
||||
// - "columna->'key'" returns "columna"
|
||||
// - "columna" returns "columna"
|
||||
// - "table.columna->>'val'" returns "table.columna"
|
||||
func extractSourceColumn(colName string) string {
|
||||
// Check for PostgreSQL JSON operators: -> and ->>
|
||||
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
return colName
|
||||
}
|
||||
|
||||
// ValidateColumn validates a single column name
|
||||
// Returns nil if valid, error if invalid
|
||||
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||
// Handles PostgreSQL JSON operators (-> and ->>)
|
||||
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||
// Allow empty columns
|
||||
if column == "" {
|
||||
@@ -106,8 +124,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract source column name (remove JSON operators like ->> or ->)
|
||||
sourceColumn := extractSourceColumn(column)
|
||||
|
||||
// Check if column exists in model
|
||||
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
|
||||
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
||||
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||
}
|
||||
|
||||
|
||||
124
pkg/common/validation_json_test.go
Normal file
124
pkg/common/validation_json_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractSourceColumn(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple column name",
|
||||
input: "columna",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with ->> operator",
|
||||
input: "columna->>'val'",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with -> operator",
|
||||
input: "columna->'key'",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with table prefix and ->> operator",
|
||||
input: "table.columna->>'val'",
|
||||
expected: "table.columna",
|
||||
},
|
||||
{
|
||||
name: "column with table prefix and -> operator",
|
||||
input: "table.columna->'key'",
|
||||
expected: "table.columna",
|
||||
},
|
||||
{
|
||||
name: "complex JSON path with ->>",
|
||||
input: "data->>'nested'->>'value'",
|
||||
expected: "data",
|
||||
},
|
||||
{
|
||||
name: "column with spaces before operator",
|
||||
input: "columna ->>'val'",
|
||||
expected: "columna",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := extractSourceColumn(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnWithJSONOperators(t *testing.T) {
|
||||
// Create a test model
|
||||
type TestModel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Data string `json:"data"` // JSON column
|
||||
Metadata string `json:"metadata"`
|
||||
}
|
||||
|
||||
validator := NewColumnValidator(TestModel{})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
column string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple valid column",
|
||||
column: "name",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid column with ->> operator",
|
||||
column: "data->>'field'",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid column with -> operator",
|
||||
column: "metadata->'key'",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid column",
|
||||
column: "invalid_column",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid column with ->> operator",
|
||||
column: "invalid_column->>'field'",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "cql prefixed column (always valid)",
|
||||
column: "cql_computed",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty column",
|
||||
column: "",
|
||||
shouldErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validator.ValidateColumn(tc.column)
|
||||
if tc.shouldErr && err == nil {
|
||||
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
|
||||
}
|
||||
if !tc.shouldErr && err != nil {
|
||||
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -663,7 +663,14 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Create insert query
|
||||
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
|
||||
query := tx.NewInsert().Model(modelValue)
|
||||
|
||||
// Only set Table() if the model doesn't provide a table name via TableNameProvider
|
||||
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
query = query.Returning("*")
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
itemHookCtx := &HookContext{
|
||||
@@ -1640,10 +1647,13 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
||||
data = h.normalizeResultArray(data)
|
||||
}
|
||||
|
||||
response := common.Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Metadata: metadata,
|
||||
response := data
|
||||
if response == nil {
|
||||
response = common.Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Metadata: metadata,
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
|
||||
@@ -480,12 +480,32 @@ func (h *Handler) parseCommaSeparated(value string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||
// Examples:
|
||||
// - "columna->>'val'" returns "columna"
|
||||
// - "columna->'key'" returns "columna"
|
||||
// - "columna" returns "columna"
|
||||
// - "table.columna->>'val'" returns "table.columna"
|
||||
func extractSourceColumn(colName string) string {
|
||||
// Check for PostgreSQL JSON operators: -> and ->>
|
||||
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
return colName
|
||||
}
|
||||
|
||||
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||
if model == nil {
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// Extract the source column name (remove JSON operators like ->> or ->)
|
||||
sourceColName := extractSourceColumn(colName)
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
// Dereference pointer if needed
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
@@ -506,19 +526,19 @@ func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) refl
|
||||
if jsonTag != "" {
|
||||
// Parse JSON tag (format: "name,omitempty")
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if parts[0] == colName {
|
||||
if parts[0] == sourceColName {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
// Check field name (case-insensitive)
|
||||
if strings.EqualFold(field.Name, colName) {
|
||||
if strings.EqualFold(field.Name, sourceColName) {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
|
||||
// Check snake_case conversion
|
||||
snakeCaseName := toSnakeCase(field.Name)
|
||||
if snakeCaseName == colName {
|
||||
if snakeCaseName == sourceColName {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user