diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 2a25085..2fda683 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -1,6 +1,7 @@ package reflection import ( + "encoding/json" "fmt" "reflect" "strconv" @@ -897,6 +898,319 @@ func GetRelationModel(model interface{}, fieldName string) interface{} { return currentModel } +// MapToStruct populates a struct from a map while preserving custom types +// It uses reflection to set struct fields based on map keys, matching by: +// 1. Bun tag column name +// 2. Gorm tag column name +// 3. JSON tag name +// 4. Field name (case-insensitive) +// This preserves custom types that implement driver.Valuer like SqlJSONB +func MapToStruct(dataMap map[string]interface{}, target interface{}) error { + if dataMap == nil || target == nil { + return fmt.Errorf("dataMap and target cannot be nil") + } + + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return fmt.Errorf("target must be a pointer to a struct") + } + + targetValue = targetValue.Elem() + if targetValue.Kind() != reflect.Struct { + return fmt.Errorf("target must be a pointer to a struct") + } + + targetType := targetValue.Type() + + // Create a map of column names to field indices for faster lookup + columnToField := make(map[string]int) + for i := 0; i < targetType.NumField(); i++ { + field := targetType.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Build list of possible column names for this field + var columnNames []string + + // 1. Bun tag + if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" { + if colName := ExtractColumnFromBunTag(bunTag); colName != "" { + columnNames = append(columnNames, colName) + } + } + + // 2. Gorm tag + if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" { + if colName := ExtractColumnFromGormTag(gormTag); colName != "" { + columnNames = append(columnNames, colName) + } + } + + // 3. JSON tag + if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" { + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" { + columnNames = append(columnNames, parts[0]) + } + } + + // 4. Field name variations + columnNames = append(columnNames, field.Name) + columnNames = append(columnNames, strings.ToLower(field.Name)) + columnNames = append(columnNames, ToSnakeCase(field.Name)) + + // Map all column name variations to this field index + for _, colName := range columnNames { + columnToField[strings.ToLower(colName)] = i + } + } + + // Iterate through the map and set struct fields + for key, value := range dataMap { + // Find the field index for this key + fieldIndex, found := columnToField[strings.ToLower(key)] + if !found { + // Skip keys that don't map to any field + continue + } + + field := targetValue.Field(fieldIndex) + if !field.CanSet() { + continue + } + + // Set the value, preserving custom types + if err := setFieldValue(field, value); err != nil { + return fmt.Errorf("failed to set field %s: %w", targetType.Field(fieldIndex).Name, err) + } + } + + return nil +} + +// setFieldValue sets a reflect.Value from an interface{} value, handling type conversions +func setFieldValue(field reflect.Value, value interface{}) error { + if value == nil { + // Set zero value for nil + field.Set(reflect.Zero(field.Type())) + return nil + } + + valueReflect := reflect.ValueOf(value) + + // If types match exactly, just set it + if valueReflect.Type().AssignableTo(field.Type()) { + field.Set(valueReflect) + return nil + } + + // Handle pointer fields + if field.Kind() == reflect.Ptr { + if valueReflect.Kind() != reflect.Ptr { + // Create a new pointer and set its value + newPtr := reflect.New(field.Type().Elem()) + if err := setFieldValue(newPtr.Elem(), value); err != nil { + return err + } + field.Set(newPtr) + return nil + } + } + + // Handle conversions for basic types + switch field.Kind() { + case reflect.String: + if str, ok := value.(string); ok { + field.SetString(str) + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if num, ok := convertToInt64(value); ok { + if field.OverflowInt(num) { + return fmt.Errorf("integer overflow") + } + field.SetInt(num) + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if num, ok := convertToUint64(value); ok { + if field.OverflowUint(num) { + return fmt.Errorf("unsigned integer overflow") + } + field.SetUint(num) + return nil + } + case reflect.Float32, reflect.Float64: + if num, ok := convertToFloat64(value); ok { + if field.OverflowFloat(num) { + return fmt.Errorf("float overflow") + } + field.SetFloat(num) + return nil + } + case reflect.Bool: + if b, ok := value.(bool); ok { + field.SetBool(b) + return nil + } + case reflect.Slice: + // Handle []byte specially (for types like SqlJSONB) + if field.Type().Elem().Kind() == reflect.Uint8 { + switch v := value.(type) { + case []byte: + field.SetBytes(v) + return nil + case string: + field.SetBytes([]byte(v)) + return nil + case map[string]interface{}, []interface{}: + // Marshal complex types to JSON for SqlJSONB fields + jsonBytes, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal value to JSON: %w", err) + } + field.SetBytes(jsonBytes) + return nil + } + } + } + + // Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time]) + if field.Kind() == reflect.Struct { + // Try to find a "Val" field (for SqlNull types) and set it + valField := field.FieldByName("Val") + if valField.IsValid() && valField.CanSet() { + // Also set Valid field to true + validField := field.FieldByName("Valid") + if validField.IsValid() && validField.CanSet() && validField.Kind() == reflect.Bool { + // Set the Val field + if err := setFieldValue(valField, value); err != nil { + return err + } + // Set Valid to true + validField.SetBool(true) + return nil + } + } + } + + // If we can convert the type, do it + if valueReflect.Type().ConvertibleTo(field.Type()) { + field.Set(valueReflect.Convert(field.Type())) + return nil + } + + return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type()) +} + +// convertToInt64 attempts to convert various types to int64 +func convertToInt64(value interface{}) (int64, bool) { + switch v := value.(type) { + case int: + return int64(v), true + case int8: + return int64(v), true + case int16: + return int64(v), true + case int32: + return int64(v), true + case int64: + return v, true + case uint: + return int64(v), true + case uint8: + return int64(v), true + case uint16: + return int64(v), true + case uint32: + return int64(v), true + case uint64: + return int64(v), true + case float32: + return int64(v), true + case float64: + return int64(v), true + case string: + if num, err := strconv.ParseInt(v, 10, 64); err == nil { + return num, true + } + } + return 0, false +} + +// convertToUint64 attempts to convert various types to uint64 +func convertToUint64(value interface{}) (uint64, bool) { + switch v := value.(type) { + case int: + return uint64(v), true + case int8: + return uint64(v), true + case int16: + return uint64(v), true + case int32: + return uint64(v), true + case int64: + return uint64(v), true + case uint: + return uint64(v), true + case uint8: + return uint64(v), true + case uint16: + return uint64(v), true + case uint32: + return uint64(v), true + case uint64: + return v, true + case float32: + return uint64(v), true + case float64: + return uint64(v), true + case string: + if num, err := strconv.ParseUint(v, 10, 64); err == nil { + return num, true + } + } + return 0, false +} + +// convertToFloat64 attempts to convert various types to float64 +func convertToFloat64(value interface{}) (float64, bool) { + switch v := value.(type) { + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + case float32: + return float64(v), true + case float64: + return v, true + case string: + if num, err := strconv.ParseFloat(v, 64); err == nil { + return num, true + } + } + return 0, false +} + // getRelationModelSingleLevel gets the model type for a single level field (non-recursive) // This is a helper function used by GetRelationModel to handle one level at a time func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} { diff --git a/pkg/reflection/model_utils_sqltypes_test.go b/pkg/reflection/model_utils_sqltypes_test.go new file mode 100644 index 0000000..030c6d4 --- /dev/null +++ b/pkg/reflection/model_utils_sqltypes_test.go @@ -0,0 +1,266 @@ +package reflection_test + +import ( + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +func TestMapToStruct_SqlJSONB_PreservesDriverValuer(t *testing.T) { + // Test that SqlJSONB type preserves driver.Valuer interface + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Meta common.SqlJSONB `bun:"meta" json:"meta"` + } + + dataMap := map[string]interface{}{ + "id": int64(123), + "meta": map[string]interface{}{ + "key": "value", + "num": 42, + }, + } + + var result TestModel + err := reflection.MapToStruct(dataMap, &result) + if err != nil { + t.Fatalf("MapToStruct() error = %v", err) + } + + // Verify the field was set + if result.ID != 123 { + t.Errorf("ID = %v, want 123", result.ID) + } + + // Verify SqlJSONB was populated + if len(result.Meta) == 0 { + t.Error("Meta is empty, want non-empty") + } + + // Most importantly: verify driver.Valuer interface works + value, err := result.Meta.Value() + if err != nil { + t.Errorf("Meta.Value() error = %v, want nil", err) + } + + // Value should return a string representation of the JSON + if value == nil { + t.Error("Meta.Value() returned nil, want non-nil") + } + + // Check it's a valid JSON string + if str, ok := value.(string); ok { + if len(str) == 0 { + t.Error("Meta.Value() returned empty string, want valid JSON") + } + t.Logf("SqlJSONB.Value() returned: %s", str) + } else { + t.Errorf("Meta.Value() returned type %T, want string", value) + } +} + +func TestMapToStruct_SqlJSONB_FromBytes(t *testing.T) { + // Test that SqlJSONB can be set from []byte directly + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Meta common.SqlJSONB `bun:"meta" json:"meta"` + } + + jsonBytes := []byte(`{"direct":"bytes"}`) + dataMap := map[string]interface{}{ + "id": int64(456), + "meta": jsonBytes, + } + + var result TestModel + err := reflection.MapToStruct(dataMap, &result) + if err != nil { + t.Fatalf("MapToStruct() error = %v", err) + } + + if result.ID != 456 { + t.Errorf("ID = %v, want 456", result.ID) + } + + if string(result.Meta) != string(jsonBytes) { + t.Errorf("Meta = %s, want %s", string(result.Meta), string(jsonBytes)) + } + + // Verify driver.Valuer works + value, err := result.Meta.Value() + if err != nil { + t.Errorf("Meta.Value() error = %v", err) + } + if value == nil { + t.Error("Meta.Value() returned nil") + } +} + +func TestMapToStruct_AllSqlTypes(t *testing.T) { + // Test model with all SQL custom types + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + CreatedAt common.SqlTimeStamp `bun:"created_at" json:"created_at"` + BirthDate common.SqlDate `bun:"birth_date" json:"birth_date"` + LoginTime common.SqlTime `bun:"login_time" json:"login_time"` + Meta common.SqlJSONB `bun:"meta" json:"meta"` + Tags common.SqlJSONB `bun:"tags" json:"tags"` + } + + now := time.Now() + birthDate := time.Date(1990, 1, 15, 0, 0, 0, 0, time.UTC) + loginTime := time.Date(0, 1, 1, 14, 30, 0, 0, time.UTC) + + dataMap := map[string]interface{}{ + "id": int64(100), + "name": "Test User", + "created_at": now, + "birth_date": birthDate, + "login_time": loginTime, + "meta": map[string]interface{}{ + "role": "admin", + "active": true, + }, + "tags": []interface{}{"golang", "testing", "sql"}, + } + + var result TestModel + err := reflection.MapToStruct(dataMap, &result) + if err != nil { + t.Fatalf("MapToStruct() error = %v", err) + } + + // Verify basic fields + if result.ID != 100 { + t.Errorf("ID = %v, want 100", result.ID) + } + if result.Name != "Test User" { + t.Errorf("Name = %v, want 'Test User'", result.Name) + } + + // Verify SqlTimeStamp + if !result.CreatedAt.Valid { + t.Error("CreatedAt.Valid = false, want true") + } + if !result.CreatedAt.Val.Equal(now) { + t.Errorf("CreatedAt.Val = %v, want %v", result.CreatedAt.Val, now) + } + + // Verify driver.Valuer for SqlTimeStamp + tsValue, err := result.CreatedAt.Value() + if err != nil { + t.Errorf("CreatedAt.Value() error = %v", err) + } + if tsValue == nil { + t.Error("CreatedAt.Value() returned nil") + } + + // Verify SqlDate + if !result.BirthDate.Valid { + t.Error("BirthDate.Valid = false, want true") + } + if !result.BirthDate.Val.Equal(birthDate) { + t.Errorf("BirthDate.Val = %v, want %v", result.BirthDate.Val, birthDate) + } + + // Verify driver.Valuer for SqlDate + dateValue, err := result.BirthDate.Value() + if err != nil { + t.Errorf("BirthDate.Value() error = %v", err) + } + if dateValue == nil { + t.Error("BirthDate.Value() returned nil") + } + + // Verify SqlTime + if !result.LoginTime.Valid { + t.Error("LoginTime.Valid = false, want true") + } + + // Verify driver.Valuer for SqlTime + timeValue, err := result.LoginTime.Value() + if err != nil { + t.Errorf("LoginTime.Value() error = %v", err) + } + if timeValue == nil { + t.Error("LoginTime.Value() returned nil") + } + + // Verify SqlJSONB for Meta + if len(result.Meta) == 0 { + t.Error("Meta is empty") + } + metaValue, err := result.Meta.Value() + if err != nil { + t.Errorf("Meta.Value() error = %v", err) + } + if metaValue == nil { + t.Error("Meta.Value() returned nil") + } + + // Verify SqlJSONB for Tags + if len(result.Tags) == 0 { + t.Error("Tags is empty") + } + tagsValue, err := result.Tags.Value() + if err != nil { + t.Errorf("Tags.Value() error = %v", err) + } + if tagsValue == nil { + t.Error("Tags.Value() returned nil") + } + + t.Logf("All SQL types successfully preserved driver.Valuer interface:") + t.Logf(" - SqlTimeStamp: %v", tsValue) + t.Logf(" - SqlDate: %v", dateValue) + t.Logf(" - SqlTime: %v", timeValue) + t.Logf(" - SqlJSONB (Meta): %v", metaValue) + t.Logf(" - SqlJSONB (Tags): %v", tagsValue) +} + +func TestMapToStruct_SqlNull_NilValues(t *testing.T) { + // Test that SqlNull types handle nil values correctly + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + UpdatedAt common.SqlTimeStamp `bun:"updated_at" json:"updated_at"` + DeletedAt common.SqlTimeStamp `bun:"deleted_at" json:"deleted_at"` + } + + now := time.Now() + dataMap := map[string]interface{}{ + "id": int64(200), + "updated_at": now, + "deleted_at": nil, // Explicitly nil + } + + var result TestModel + err := reflection.MapToStruct(dataMap, &result) + if err != nil { + t.Fatalf("MapToStruct() error = %v", err) + } + + // UpdatedAt should be valid + if !result.UpdatedAt.Valid { + t.Error("UpdatedAt.Valid = false, want true") + } + if !result.UpdatedAt.Val.Equal(now) { + t.Errorf("UpdatedAt.Val = %v, want %v", result.UpdatedAt.Val, now) + } + + // DeletedAt should be invalid (null) + if result.DeletedAt.Valid { + t.Error("DeletedAt.Valid = true, want false (null)") + } + + // Verify driver.Valuer for null SqlTimeStamp + deletedValue, err := result.DeletedAt.Value() + if err != nil { + t.Errorf("DeletedAt.Value() error = %v", err) + } + if deletedValue != nil { + t.Errorf("DeletedAt.Value() = %v, want nil", deletedValue) + } +} diff --git a/pkg/reflection/model_utils_test.go b/pkg/reflection/model_utils_test.go index 8aa3db6..41b6529 100644 --- a/pkg/reflection/model_utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -1687,3 +1687,201 @@ func TestGetRelationModel_WithTags(t *testing.T) { }) } } + +func TestMapToStruct(t *testing.T) { + // Test model with various field types + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + Age int `bun:"age" json:"age"` + Active bool `bun:"active" json:"active"` + Score float64 `bun:"score" json:"score"` + Data []byte `bun:"data" json:"data"` + MetaJSON []byte `bun:"meta_json" json:"meta_json"` + } + + tests := []struct { + name string + dataMap map[string]interface{} + expected TestModel + wantErr bool + }{ + { + name: "Basic types conversion", + dataMap: map[string]interface{}{ + "id": int64(123), + "name": "Test User", + "age": 30, + "active": true, + "score": 95.5, + }, + expected: TestModel{ + ID: 123, + Name: "Test User", + Age: 30, + Active: true, + Score: 95.5, + }, + wantErr: false, + }, + { + name: "Byte slice (SqlJSONB-like) from []byte", + dataMap: map[string]interface{}{ + "id": int64(456), + "name": "JSON Test", + "data": []byte(`{"key":"value"}`), + }, + expected: TestModel{ + ID: 456, + Name: "JSON Test", + Data: []byte(`{"key":"value"}`), + }, + wantErr: false, + }, + { + name: "Byte slice from string", + dataMap: map[string]interface{}{ + "id": int64(789), + "data": "string data", + }, + expected: TestModel{ + ID: 789, + Data: []byte("string data"), + }, + wantErr: false, + }, + { + name: "Byte slice from map (JSON marshal)", + dataMap: map[string]interface{}{ + "id": int64(999), + "meta_json": map[string]interface{}{ + "field1": "value1", + "field2": 42, + }, + }, + expected: TestModel{ + ID: 999, + MetaJSON: []byte(`{"field1":"value1","field2":42}`), + }, + wantErr: false, + }, + { + name: "Byte slice from slice (JSON marshal)", + dataMap: map[string]interface{}{ + "id": int64(111), + "meta_json": []interface{}{"item1", "item2", 3}, + }, + expected: TestModel{ + ID: 111, + MetaJSON: []byte(`["item1","item2",3]`), + }, + wantErr: false, + }, + { + name: "Field matching by bun tag", + dataMap: map[string]interface{}{ + "id": int64(222), + "name": "Tagged Field", + }, + expected: TestModel{ + ID: 222, + Name: "Tagged Field", + }, + wantErr: false, + }, + { + name: "Nil values", + dataMap: map[string]interface{}{ + "id": int64(333), + "data": nil, + }, + expected: TestModel{ + ID: 333, + Data: nil, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result TestModel + err := MapToStruct(tt.dataMap, &result) + + if (err != nil) != tt.wantErr { + t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Compare fields individually for better error messages + if result.ID != tt.expected.ID { + t.Errorf("ID = %v, want %v", result.ID, tt.expected.ID) + } + if result.Name != tt.expected.Name { + t.Errorf("Name = %v, want %v", result.Name, tt.expected.Name) + } + if result.Age != tt.expected.Age { + t.Errorf("Age = %v, want %v", result.Age, tt.expected.Age) + } + if result.Active != tt.expected.Active { + t.Errorf("Active = %v, want %v", result.Active, tt.expected.Active) + } + if result.Score != tt.expected.Score { + t.Errorf("Score = %v, want %v", result.Score, tt.expected.Score) + } + + // For byte slices, compare as strings for JSON data + if tt.expected.Data != nil { + if string(result.Data) != string(tt.expected.Data) { + t.Errorf("Data = %s, want %s", string(result.Data), string(tt.expected.Data)) + } + } + if tt.expected.MetaJSON != nil { + if string(result.MetaJSON) != string(tt.expected.MetaJSON) { + t.Errorf("MetaJSON = %s, want %s", string(result.MetaJSON), string(tt.expected.MetaJSON)) + } + } + }) + } +} + +func TestMapToStruct_Errors(t *testing.T) { + type TestModel struct { + ID int `bun:"id" json:"id"` + } + + tests := []struct { + name string + dataMap map[string]interface{} + target interface{} + wantErr bool + }{ + { + name: "Nil dataMap", + dataMap: nil, + target: &TestModel{}, + wantErr: true, + }, + { + name: "Nil target", + dataMap: map[string]interface{}{"id": 1}, + target: nil, + wantErr: true, + }, + { + name: "Non-pointer target", + dataMap: map[string]interface{}{"id": 1}, + target: TestModel{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := MapToStruct(tt.dataMap, tt.target) + if (err != nil) != tt.wantErr { + t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index acb1833..ce6bb64 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1220,8 +1220,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Ensure ID is in the data map for the update dataMap[pkName] = targetID - // Create update query - query := tx.NewUpdate().Table(tableName).SetMap(dataMap) + // Populate model instance from dataMap to preserve custom types (like SqlJSONB) + modelInstance := reflect.New(reflect.TypeOf(model).Elem()).Interface() + if err := reflection.MapToStruct(dataMap, modelInstance); err != nil { + return fmt.Errorf("failed to populate model from data: %w", err) + } + + // Create update query using Model() to preserve custom types and driver.Valuer interfaces + query := tx.NewUpdate().Model(modelInstance).Table(tableName) query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) // Execute BeforeScan hooks - pass query chain so hooks can modify it