diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index 2ae1a88..9740c3e 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -76,9 +76,14 @@ func GetJSONNameForField(modelType reflect.Type, fieldName string) string { return "" } - // Handle pointer types - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() + // Unwrap pointer and slice indirections to reach the struct type + for { + switch modelType.Kind() { + case reflect.Ptr, reflect.Slice: + modelType = modelType.Elem() + continue + } + break } if modelType.Kind() != reflect.Struct { diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 18f4034..e388bf1 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -541,9 +541,14 @@ func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbe func IsColumnWritable(model any, columnName string) bool { modelType := reflect.TypeOf(model) - // Unwrap pointers to get to the base struct type - for modelType != nil && modelType.Kind() == reflect.Pointer { - modelType = modelType.Elem() + // Unwrap pointers and slices to get to the base struct type + for modelType != nil { + switch modelType.Kind() { + case reflect.Ptr, reflect.Slice: + modelType = modelType.Elem() + continue + } + break } // Validate that we have a struct type @@ -878,8 +883,14 @@ func GetRelationType(model interface{}, fieldName string) RelationType { return RelationUnknown } - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() + // Unwrap pointer → slice → pointer chains to reach the underlying struct + for { + switch modelType.Kind() { + case reflect.Ptr, reflect.Slice: + modelType = modelType.Elem() + continue + } + break } if modelType == nil || modelType.Kind() != reflect.Struct { @@ -1472,9 +1483,14 @@ func convertToFloat64(value interface{}) (float64, bool) { func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool { validFields := make(map[string]bool) - // Unwrap pointers to get to the base struct type - for modelType != nil && modelType.Kind() == reflect.Pointer { - modelType = modelType.Elem() + // Unwrap pointers and slices to get to the base struct type + for modelType != nil { + switch modelType.Kind() { + case reflect.Ptr, reflect.Slice: + modelType = modelType.Elem() + continue + } + break } if modelType == nil || modelType.Kind() != reflect.Struct { @@ -1535,8 +1551,13 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{ return nil } - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() + for { + switch modelType.Kind() { + case reflect.Ptr, reflect.Slice: + modelType = modelType.Elem() + continue + } + break } if modelType == nil || modelType.Kind() != reflect.Struct { @@ -1599,17 +1620,16 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{ return nil } - if targetType.Kind() == reflect.Slice { - targetType = targetType.Elem() - if targetType == nil { - return nil - } - } - if targetType.Kind() == reflect.Ptr { - targetType = targetType.Elem() - if targetType == nil { - return nil + for { + switch targetType.Kind() { + case reflect.Ptr, reflect.Slice: + targetType = targetType.Elem() + if targetType == nil { + return nil + } + continue } + break } if targetType.Kind() != reflect.Struct {