diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index cc6787f..155f30c 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -47,3 +47,20 @@ func ExtractTableNameOnly(fullName string) string { return fullName[startIndex:] } + +// GetPointerElement returns the element type if the provided reflect.Type is a pointer. +// If the type is a slice of pointers, it returns the element type of the pointer within the slice. +// If neither condition is met, it returns the original type. +func GetPointerElement(v reflect.Type) reflect.Type { + if v.Kind() == reflect.Ptr { + return v.Elem() + } + if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr { + subElem := v.Elem() + if subElem.Elem().Kind() == reflect.Ptr { + return subElem.Elem().Elem() + } + return v.Elem() + } + return v +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index d57826c..4a1aea8 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -702,7 +702,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url pkName := reflection.GetPrimaryKeyName(model) // First, read the existing record from the database - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := h.db.NewSelect().Model(existingRecord) // Apply conditions to select @@ -850,7 +850,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url for _, item := range updates { if itemID, ok := item["id"]; ok { // First, read the existing record - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { @@ -958,7 +958,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { // First, read the existing record - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 4d567ae..3a47fad 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1240,7 +1240,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) // First, read the existing record from the database - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { @@ -1294,9 +1294,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Populate model instance from dataMap to preserve custom types (like SqlJSONB) // Get the type of the model, handling both pointer and non-pointer types modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } + modelType = reflection.GetPointerElement(modelType) modelInstance := reflect.New(modelType).Interface() if err := reflection.MapToStruct(dataMap, modelInstance); err != nil { return fmt.Errorf("failed to populate model from data: %w", err) @@ -1600,9 +1598,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id // First, fetch the record that will be deleted modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } + modelType = reflection.GetPointerElement(modelType) recordToDelete := reflect.New(modelType).Interface() selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)