mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-31 08:44:25 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc3b621380 | ||
|
|
a4dd2a7086 | ||
|
|
3ec2e5f15a |
@@ -71,15 +71,14 @@ func (n *SqlNull[T]) Scan(value any) error {
|
||||
// Fallback: parse from string/bytes.
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return n.fromString(v)
|
||||
return n.FromString(v)
|
||||
case []byte:
|
||||
return n.fromString(string(v))
|
||||
return n.FromString(string(v))
|
||||
default:
|
||||
return n.fromString(fmt.Sprintf("%v", value))
|
||||
return n.FromString(fmt.Sprintf("%v", value))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *SqlNull[T]) fromString(s string) error {
|
||||
func (n *SqlNull[T]) FromString(s string) error {
|
||||
s = strings.TrimSpace(s)
|
||||
n.Valid = false
|
||||
n.Val = *new(T)
|
||||
@@ -90,19 +89,14 @@ func (n *SqlNull[T]) fromString(s string) error {
|
||||
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case int, int8, int16, int32, int64:
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
n.Val = any(int64(i)).(T) // Cast to T (e.g., int16)
|
||||
n.Valid = true
|
||||
}
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
if u, err := strconv.ParseUint(s, 10, 64); err == nil {
|
||||
n.Val = any(u).(T)
|
||||
reflect.ValueOf(&n.Val).Elem().SetInt(i)
|
||||
n.Valid = true
|
||||
}
|
||||
case float32, float64:
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
n.Val = any(f).(T)
|
||||
reflect.ValueOf(&n.Val).Elem().SetFloat(f)
|
||||
n.Valid = true
|
||||
}
|
||||
case bool:
|
||||
@@ -124,7 +118,6 @@ func (n *SqlNull[T]) fromString(s string) error {
|
||||
n.Val = any(s).(T)
|
||||
n.Valid = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,7 +156,7 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
|
||||
// Fallback: unmarshal as string and parse.
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err == nil {
|
||||
return n.fromString(s)
|
||||
return n.FromString(s)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val)
|
||||
@@ -517,6 +510,33 @@ func TryIfInt64(v any, def int64) int64 {
|
||||
}
|
||||
|
||||
// Constructor helpers - clean and fast value creation
|
||||
func Null[T any](v T, valid bool) SqlNull[T] {
|
||||
return SqlNull[T]{Val: v, Valid: valid}
|
||||
}
|
||||
|
||||
func NewSql[T any](value any) SqlNull[T] {
|
||||
n := SqlNull[T]{}
|
||||
|
||||
if value == nil {
|
||||
return n
|
||||
}
|
||||
|
||||
// Fast path: exact match
|
||||
if v, ok := value.(T); ok {
|
||||
n.Val = v
|
||||
n.Valid = true
|
||||
return n
|
||||
}
|
||||
|
||||
// Try from another SqlNull
|
||||
if sn, ok := value.(SqlNull[T]); ok {
|
||||
return sn
|
||||
}
|
||||
|
||||
// Convert via string
|
||||
_ = n.FromString(fmt.Sprintf("%v", value))
|
||||
return n
|
||||
}
|
||||
|
||||
func NewSqlInt16(v int16) SqlInt16 {
|
||||
return SqlInt16{Val: v, Valid: true}
|
||||
|
||||
@@ -16,11 +16,11 @@ func TestNewSqlInt16(t *testing.T) {
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, NewSqlInt16(42)},
|
||||
{"int", 42, Null(int16(42), true)},
|
||||
{"int32", int32(100), NewSqlInt16(100)},
|
||||
{"int64", int64(200), NewSqlInt16(200)},
|
||||
{"string", "123", NewSqlInt16(123)},
|
||||
{"nil", nil, NewSqlInt16(0)},
|
||||
{"nil", nil, Null(int16(0), false)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -42,9 +42,9 @@ func TestNewSqlInt16_Value(t *testing.T) {
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", NewSqlInt16(0), nil},
|
||||
{"positive", NewSqlInt16(42), int64(42)},
|
||||
{"negative", NewSqlInt16(-10), int64(-10)},
|
||||
{"zero", Null(int16(0), false), nil},
|
||||
{"positive", NewSqlInt16(42), int16(42)},
|
||||
{"negative", NewSqlInt16(-10), int16(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -486,7 +486,7 @@ func TestSqlUUID_Value(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID.String() {
|
||||
if val != testUUID {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
// MockDatabase implements common.Database interface for testing
|
||||
type MockDatabase struct {
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
||||
}
|
||||
|
||||
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
}{
|
||||
{
|
||||
name: "No variables",
|
||||
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
|
||||
// TestGetIPAddress tests IP address extraction
|
||||
func TestGetIPAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
userCtx := &security.UserContext{
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "456",
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "ABC456",
|
||||
SessionRID: 456,
|
||||
}
|
||||
|
||||
metainfo := map[string]interface{}{
|
||||
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "456")
|
||||
},
|
||||
}, {
|
||||
name: "Replace [id_session]",
|
||||
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "ABC456")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user