From 7d6a9025f55b0ea7257e2362302a45915f11ab84 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 20 Nov 2025 09:40:11 +0200 Subject: [PATCH] Fixed hardcoded id --- pkg/common/adapters/database/bun.go | 5 +++++ pkg/common/adapters/database/gorm.go | 7 +++++++ pkg/common/recursive_crud.go | 15 +++++++++------ pkg/restheadspec/handler.go | 15 ++++++++++----- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index edc6fad..7a5cc50 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -388,12 +388,17 @@ func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuer } func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { + pkName := reflection.GetPrimaryKeyName(b.model) for column, value := range values { // Validate column is writable if model is set if b.model != nil && !reflection.IsColumnWritable(b.model, column) { // Skip scan-only columns continue } + if pkName != "" && column == pkName { + // Skip primary key updates + continue + } b.query = b.query.Set(column+" = ?", value) } return b diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index fb0d74d..666d9cb 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -369,13 +369,20 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue } func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { + // Filter out read-only columns if model is set if g.model != nil { + pkName := reflection.GetPrimaryKeyName(g.model) filteredValues := make(map[string]interface{}) for column, value := range values { + if pkName != "" && column == pkName { + // Skip primary key updates + continue + } if reflection.IsColumnWritable(g.model, column) { filteredValues[column] = value } + } g.updates = filteredValues } else { diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index 688da6f..f7f06a7 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -111,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( // Inject parent IDs for foreign key resolution p.injectForeignKeys(regularData, modelType, parentIDs) + // Get the primary key name for this model + pkName := reflection.GetPrimaryKeyName(model) + // Process based on operation switch strings.ToLower(operation) { case "insert", "create": @@ -128,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( } case "update": - rows, err := p.processUpdate(ctx, regularData, tableName, data["id"]) + rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName]) if err != nil { return nil, fmt.Errorf("update failed: %w", err) } - result.ID = data["id"] + result.ID = data[pkName] result.AffectedRows = rows result.Data = regularData // Process child relations for update - if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil { + if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil { return nil, fmt.Errorf("failed to process child relations: %w", err) } case "delete": // Process child relations first (for referential integrity) - if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil { + if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil { return nil, fmt.Errorf("failed to process child relations before delete: %w", err) } - rows, err := p.processDelete(ctx, tableName, data["id"]) + rows, err := p.processDelete(ctx, tableName, data[pkName]) if err != nil { return nil, fmt.Errorf("delete failed: %w", err) } - result.ID = data["id"] + result.ID = data[pkName] result.AffectedRows = rows result.Data = regularData diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index d7ac3cd..ecf226b 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -807,12 +807,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id nestedRelations = relations } + // Get the primary key name for the model + pkName := reflection.GetPrimaryKeyName(model) + // Ensure ID is in the data map for the update - dataMap["id"] = targetID + dataMap[pkName] = targetID // Create update query query := tx.NewUpdate().Table(tableName).SetMap(dataMap) - pkName := reflection.GetPrimaryKeyName(model) query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) // Execute BeforeScan hooks - pass query chain so hooks can modify it @@ -936,6 +938,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id // Array of IDs or objects with ID field logger.Info("Batch delete with %d items ([]interface{})", len(v)) deletedCount := 0 + pkName := reflection.GetPrimaryKeyName(model) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { var itemID interface{} @@ -945,7 +948,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id case string: itemID = v case map[string]interface{}: - itemID = v["id"] + itemID = v[pkName] default: itemID = item } @@ -1002,9 +1005,10 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id // Array of objects with id field logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v)) deletedCount := 0 + pkName := reflection.GetPrimaryKeyName(model) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { - if itemID, ok := item["id"]; ok && itemID != nil { + if itemID, ok := item[pkName]; ok && itemID != nil { itemIDStr := fmt.Sprintf("%v", itemID) // Execute hooks for each item @@ -1052,7 +1056,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id case map[string]interface{}: // Single object with id field - if itemID, ok := v["id"]; ok && itemID != nil { + pkName := reflection.GetPrimaryKeyName(model) + if itemID, ok := v[pkName]; ok && itemID != nil { id = fmt.Sprintf("%v", itemID) } }