diff --git a/pkg/spectypes/sql_types.go b/pkg/spectypes/sql_types.go index 43afc93..1d008d1 100644 --- a/pkg/spectypes/sql_types.go +++ b/pkg/spectypes/sql_types.go @@ -4,6 +4,7 @@ package spectypes import ( "database/sql" "database/sql/driver" + "encoding/base64" "encoding/json" "fmt" "reflect" @@ -60,7 +61,34 @@ func (n *SqlNull[T]) Scan(value any) error { return nil } - // Try standard sql.Null[T] first. + // Check if T is []byte, and decode base64 if applicable + // Do this BEFORE trying sql.Null to ensure base64 is handled + var zero T + switch any(zero).(type) { + case []byte: + // For []byte types, try to decode from base64 + var strVal string + switch v := value.(type) { + case string: + strVal = v + case []byte: + strVal = string(v) + default: + strVal = fmt.Sprintf("%v", value) + } + // Try base64 decode + if decoded, err := base64.StdEncoding.DecodeString(strVal); err == nil { + n.Val = any(decoded).(T) + n.Valid = true + return nil + } + // Fallback to raw bytes + n.Val = any([]byte(strVal)).(T) + n.Valid = true + return nil + } + + // Try standard sql.Null[T] for other types. var sqlNull sql.Null[T] if err := sqlNull.Scan(value); err == nil { n.Val = sqlNull.V @@ -122,6 +150,9 @@ func (n *SqlNull[T]) FromString(s string) error { n.Val = any(u).(T) n.Valid = true } + case []byte: + n.Val = any([]byte(s)).(T) + n.Valid = true case string: n.Val = any(s).(T) n.Valid = true @@ -149,6 +180,15 @@ func (n SqlNull[T]) MarshalJSON() ([]byte, error) { if !n.Valid { return []byte("null"), nil } + + // Check if T is []byte, and encode to base64 + switch v := any(n.Val).(type) { + case []byte: + // Encode []byte as base64 + encoded := base64.StdEncoding.EncodeToString(v) + return json.Marshal(encoded) + } + return json.Marshal(n.Val) } @@ -160,8 +200,26 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error { return nil } - // Try direct unmarshal. + // Check if T is []byte, and decode from base64 var val T + switch any(val).(type) { + case []byte: + // Unmarshal as string first (JSON representation) + var s string + if err := json.Unmarshal(b, &s); err == nil { + // Decode from base64 + if decoded, err := base64.StdEncoding.DecodeString(s); err == nil { + n.Val = any(decoded).(T) + n.Valid = true + return nil + } + // Fallback to raw string as bytes + n.Val = any([]byte(s)).(T) + n.Valid = true + return nil + } + } + if err := json.Unmarshal(b, &val); err == nil { n.Val = val n.Valid = true @@ -271,13 +329,14 @@ func (n SqlNull[T]) UUID() uuid.UUID { // Type aliases for common types. type ( - SqlInt16 = SqlNull[int16] - SqlInt32 = SqlNull[int32] - SqlInt64 = SqlNull[int64] - SqlFloat64 = SqlNull[float64] - SqlBool = SqlNull[bool] - SqlString = SqlNull[string] - SqlUUID = SqlNull[uuid.UUID] + SqlInt16 = SqlNull[int16] + SqlInt32 = SqlNull[int32] + SqlInt64 = SqlNull[int64] + SqlFloat64 = SqlNull[float64] + SqlBool = SqlNull[bool] + SqlString = SqlNull[string] + SqlByteArray = SqlNull[[]byte] + SqlUUID = SqlNull[uuid.UUID] ) // SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS). @@ -581,6 +640,10 @@ func NewSqlString(v string) SqlString { return SqlString{Val: v, Valid: true} } +func NewSqlByteArray(v []byte) SqlByteArray { + return SqlByteArray{Val: v, Valid: true} +} + func NewSqlUUID(v uuid.UUID) SqlUUID { return SqlUUID{Val: v, Valid: true} } diff --git a/pkg/spectypes/sql_types_test.go b/pkg/spectypes/sql_types_test.go index 57e7614..7e743c3 100644 --- a/pkg/spectypes/sql_types_test.go +++ b/pkg/spectypes/sql_types_test.go @@ -565,3 +565,394 @@ func TestTryIfInt64(t *testing.T) { }) } } + +// TestSqlString tests SqlString without base64 (plain text) +func TestSqlString_Scan(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + valid bool + }{ + { + name: "plain string", + input: "hello world", + expected: "hello world", + valid: true, + }, + { + name: "plain text", + input: "plain text", + expected: "plain text", + valid: true, + }, + { + name: "bytes as string", + input: []byte("raw bytes"), + expected: "raw bytes", + valid: true, + }, + { + name: "nil value", + input: nil, + expected: "", + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s SqlString + if err := s.Scan(tt.input); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if s.Valid != tt.valid { + t.Errorf("expected valid=%v, got valid=%v", tt.valid, s.Valid) + } + if tt.valid && s.String() != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, s.String()) + } + }) + } +} + +func TestSqlString_JSON(t *testing.T) { + tests := []struct { + name string + inputValue string + expectedJSON string + expectedDecode string + }{ + { + name: "simple string", + inputValue: "hello world", + expectedJSON: `"hello world"`, // plain text, not base64 + expectedDecode: "hello world", + }, + { + name: "special characters", + inputValue: "test@#$%", + expectedJSON: `"test@#$%"`, // plain text, not base64 + expectedDecode: "test@#$%", + }, + { + name: "unicode string", + inputValue: "Hello 世界", + expectedJSON: `"Hello 世界"`, // plain text, not base64 + expectedDecode: "Hello 世界", + }, + { + name: "empty string", + inputValue: "", + expectedJSON: `""`, + expectedDecode: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test MarshalJSON + s := NewSqlString(tt.inputValue) + data, err := json.Marshal(s) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != tt.expectedJSON { + t.Errorf("Marshal: expected %s, got %s", tt.expectedJSON, string(data)) + } + + // Test UnmarshalJSON + var s2 SqlString + if err := json.Unmarshal(data, &s2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !s2.Valid { + t.Error("expected valid=true after unmarshal") + } + if s2.String() != tt.expectedDecode { + t.Errorf("Unmarshal: expected %q, got %q", tt.expectedDecode, s2.String()) + } + }) + } +} + +func TestSqlString_JSON_Null(t *testing.T) { + // Test null handling + var s SqlString + if err := json.Unmarshal([]byte("null"), &s); err != nil { + t.Fatalf("Unmarshal null failed: %v", err) + } + if s.Valid { + t.Error("expected invalid after unmarshaling null") + } + + // Test marshal null + data, err := json.Marshal(s) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != "null" { + t.Errorf("expected null, got %s", string(data)) + } +} + +// TestSqlByteArray_Base64 tests SqlByteArray with base64 encoding/decoding +func TestSqlByteArray_Base64_Scan(t *testing.T) { + tests := []struct { + name string + input interface{} + expected []byte + valid bool + }{ + { + name: "base64 encoded bytes from SQL", + input: "aGVsbG8gd29ybGQ=", // "hello world" in base64 + expected: []byte("hello world"), + valid: true, + }, + { + name: "plain bytes fallback", + input: "plain text", + expected: []byte("plain text"), + valid: true, + }, + { + name: "bytes base64 encoded", + input: []byte("SGVsbG8gR29waGVy"), // "Hello Gopher" in base64 + expected: []byte("Hello Gopher"), + valid: true, + }, + { + name: "bytes plain fallback", + input: []byte("raw bytes"), + expected: []byte("raw bytes"), + valid: true, + }, + { + name: "binary data", + input: "AQIDBA==", // []byte{1, 2, 3, 4} in base64 + expected: []byte{1, 2, 3, 4}, + valid: true, + }, + { + name: "nil value", + input: nil, + expected: nil, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var b SqlByteArray + if err := b.Scan(tt.input); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if b.Valid != tt.valid { + t.Errorf("expected valid=%v, got valid=%v", tt.valid, b.Valid) + } + if tt.valid { + if string(b.Val) != string(tt.expected) { + t.Errorf("expected %q, got %q", tt.expected, b.Val) + } + } + }) + } +} + +func TestSqlByteArray_Base64_JSON(t *testing.T) { + tests := []struct { + name string + inputValue []byte + expectedJSON string + expectedDecode []byte + }{ + { + name: "text bytes", + inputValue: []byte("hello world"), + expectedJSON: `"aGVsbG8gd29ybGQ="`, // base64 encoded + expectedDecode: []byte("hello world"), + }, + { + name: "binary data", + inputValue: []byte{0x01, 0x02, 0x03, 0x04, 0xFF}, + expectedJSON: `"AQIDBP8="`, // base64 encoded + expectedDecode: []byte{0x01, 0x02, 0x03, 0x04, 0xFF}, + }, + { + name: "empty bytes", + inputValue: []byte{}, + expectedJSON: `""`, // base64 of empty bytes + expectedDecode: []byte{}, + }, + { + name: "unicode bytes", + inputValue: []byte("Hello 世界"), + expectedJSON: `"SGVsbG8g5LiW55WM"`, // base64 encoded + expectedDecode: []byte("Hello 世界"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test MarshalJSON + b := NewSqlByteArray(tt.inputValue) + data, err := json.Marshal(b) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != tt.expectedJSON { + t.Errorf("Marshal: expected %s, got %s", tt.expectedJSON, string(data)) + } + + // Test UnmarshalJSON + var b2 SqlByteArray + if err := json.Unmarshal(data, &b2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !b2.Valid { + t.Error("expected valid=true after unmarshal") + } + if string(b2.Val) != string(tt.expectedDecode) { + t.Errorf("Unmarshal: expected %v, got %v", tt.expectedDecode, b2.Val) + } + }) + } +} + +func TestSqlByteArray_Base64_JSON_Null(t *testing.T) { + // Test null handling + var b SqlByteArray + if err := json.Unmarshal([]byte("null"), &b); err != nil { + t.Fatalf("Unmarshal null failed: %v", err) + } + if b.Valid { + t.Error("expected invalid after unmarshaling null") + } + + // Test marshal null + data, err := json.Marshal(b) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != "null" { + t.Errorf("expected null, got %s", string(data)) + } +} + +func TestSqlByteArray_Value(t *testing.T) { + tests := []struct { + name string + input SqlByteArray + expected interface{} + }{ + { + name: "valid bytes", + input: NewSqlByteArray([]byte("test data")), + expected: []byte("test data"), + }, + { + name: "empty bytes", + input: NewSqlByteArray([]byte{}), + expected: []byte{}, + }, + { + name: "invalid", + input: SqlByteArray{Valid: false}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.input.Value() + if err != nil { + t.Fatalf("Value failed: %v", err) + } + if tt.expected == nil && val != nil { + t.Errorf("expected nil, got %v", val) + } + if tt.expected != nil && val == nil { + t.Errorf("expected %v, got nil", tt.expected) + } + if tt.expected != nil && val != nil { + if string(val.([]byte)) != string(tt.expected.([]byte)) { + t.Errorf("expected %v, got %v", tt.expected, val) + } + } + }) + } +} + +// TestSqlString_RoundTrip tests complete round-trip: Go -> JSON -> Go -> SQL -> Go +func TestSqlString_RoundTrip(t *testing.T) { + original := "Test String with Special Chars: @#$%^&*()" + + // Go -> JSON + s1 := NewSqlString(original) + jsonData, err := json.Marshal(s1) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // JSON -> Go + var s2 SqlString + if err := json.Unmarshal(jsonData, &s2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + // Go -> SQL (Value) + _, err = s2.Value() + if err != nil { + t.Fatalf("Value failed: %v", err) + } + + // SQL -> Go (Scan plain text) + var s3 SqlString + // Simulate SQL driver returning plain text value + if err := s3.Scan(original); err != nil { + t.Fatalf("Scan failed: %v", err) + } + + // Verify round-trip + if s3.String() != original { + t.Errorf("Round-trip failed: expected %q, got %q", original, s3.String()) + } +} + +// TestSqlByteArray_Base64_RoundTrip tests complete round-trip: Go -> JSON -> Go -> SQL -> Go +func TestSqlByteArray_Base64_RoundTrip(t *testing.T) { + original := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0xFF, 0xFE} // "Hello " + binary data + + // Go -> JSON + b1 := NewSqlByteArray(original) + jsonData, err := json.Marshal(b1) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // JSON -> Go + var b2 SqlByteArray + if err := json.Unmarshal(jsonData, &b2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + // Go -> SQL (Value) + _, err = b2.Value() + if err != nil { + t.Fatalf("Value failed: %v", err) + } + + // SQL -> Go (Scan with base64) + var b3 SqlByteArray + // Simulate SQL driver returning base64 encoded value + if err := b3.Scan("SGVsbG8g//4="); err != nil { + t.Fatalf("Scan failed: %v", err) + } + + // Verify round-trip + if string(b3.Val) != string(original) { + t.Errorf("Round-trip failed: expected %v, got %v", original, b3.Val) + } +} +