From 62a8e56f1b24b222f542c86c165bcb82ca107c65 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 6 Jan 2026 10:45:23 +0200 Subject: [PATCH] =?UTF-8?q?feat(reflection):=20=E2=9C=A8=20add=20GetPointe?= =?UTF-8?q?rElement=20function=20for=20type=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced GetPointerElement to simplify pointer type extraction. * Updated handleUpdate methods to utilize GetPointerElement for better clarity and maintainability. --- pkg/reflection/helpers.go | 17 +++++++++++++++++ pkg/resolvespec/handler.go | 6 +++--- pkg/restheadspec/handler.go | 10 +++------- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index cc6787f..155f30c 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -47,3 +47,20 @@ func ExtractTableNameOnly(fullName string) string { return fullName[startIndex:] } + +// GetPointerElement returns the element type if the provided reflect.Type is a pointer. +// If the type is a slice of pointers, it returns the element type of the pointer within the slice. +// If neither condition is met, it returns the original type. +func GetPointerElement(v reflect.Type) reflect.Type { + if v.Kind() == reflect.Ptr { + return v.Elem() + } + if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr { + subElem := v.Elem() + if subElem.Elem().Kind() == reflect.Ptr { + return subElem.Elem().Elem() + } + return v.Elem() + } + return v +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index d57826c..4a1aea8 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -702,7 +702,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url pkName := reflection.GetPrimaryKeyName(model) // First, read the existing record from the database - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := h.db.NewSelect().Model(existingRecord) // Apply conditions to select @@ -850,7 +850,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url for _, item := range updates { if itemID, ok := item["id"]; ok { // First, read the existing record - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { @@ -958,7 +958,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { // First, read the existing record - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 4d567ae..3a47fad 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1240,7 +1240,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) // First, read the existing record from the database - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { @@ -1294,9 +1294,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Populate model instance from dataMap to preserve custom types (like SqlJSONB) // Get the type of the model, handling both pointer and non-pointer types modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } + modelType = reflection.GetPointerElement(modelType) modelInstance := reflect.New(modelType).Interface() if err := reflection.MapToStruct(dataMap, modelInstance); err != nil { return fmt.Errorf("failed to populate model from data: %w", err) @@ -1600,9 +1598,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id // First, fetch the record that will be deleted modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } + modelType = reflection.GetPointerElement(modelType) recordToDelete := reflect.New(modelType).Interface() selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)