From 0ac207d80fd3a696ce61ac2b815703878501610b Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 13 Jan 2026 11:33:45 +0200 Subject: [PATCH] fix: better update handling --- pkg/reflection/model_utils.go | 38 +- pkg/reflection/spectypes_integration_test.go | 364 +++++++++++++++++++ pkg/resolvespec/handler.go | 254 +++++++++---- pkg/restheadspec/handler.go | 53 +-- 4 files changed, 590 insertions(+), 119 deletions(-) create mode 100644 pkg/reflection/spectypes_integration_test.go diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 150ae9d..b07d4d2 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -948,29 +948,35 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error { // Build list of possible column names for this field var columnNames []string - // 1. Bun tag - if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" { - if colName := ExtractColumnFromBunTag(bunTag); colName != "" { - columnNames = append(columnNames, colName) - } - } - - // 2. Gorm tag - if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" { - if colName := ExtractColumnFromGormTag(gormTag); colName != "" { - columnNames = append(columnNames, colName) - } - } - - // 3. JSON tag + // 1. JSON tag (primary - most common) + jsonFound := false if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" { parts := strings.Split(jsonTag, ",") if len(parts) > 0 && parts[0] != "" { columnNames = append(columnNames, parts[0]) + jsonFound = true } } - // 4. Field name variations + // 2. Bun tag (fallback if no JSON tag) + if !jsonFound { + if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" { + if colName := ExtractColumnFromBunTag(bunTag); colName != "" { + columnNames = append(columnNames, colName) + } + } + } + + // 3. Gorm tag (fallback if no JSON tag) + if !jsonFound { + if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" { + if colName := ExtractColumnFromGormTag(gormTag); colName != "" { + columnNames = append(columnNames, colName) + } + } + } + + // 4. Field name variations (last resort) columnNames = append(columnNames, field.Name) columnNames = append(columnNames, strings.ToLower(field.Name)) // columnNames = append(columnNames, ToSnakeCase(field.Name)) diff --git a/pkg/reflection/spectypes_integration_test.go b/pkg/reflection/spectypes_integration_test.go new file mode 100644 index 0000000..34bfb15 --- /dev/null +++ b/pkg/reflection/spectypes_integration_test.go @@ -0,0 +1,364 @@ +package reflection + +import ( + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/spectypes" + "github.com/google/uuid" +) + +// TestModel contains all spectypes custom types +type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Name spectypes.SqlString `bun:"name" json:"name"` + Age spectypes.SqlInt64 `bun:"age" json:"age"` + Score spectypes.SqlFloat64 `bun:"score" json:"score"` + Active spectypes.SqlBool `bun:"active" json:"active"` + UUID spectypes.SqlUUID `bun:"uuid" json:"uuid"` + CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"` + BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"` + StartTime spectypes.SqlTime `bun:"start_time" json:"start_time"` + Metadata spectypes.SqlJSONB `bun:"metadata" json:"metadata"` + Count16 spectypes.SqlInt16 `bun:"count16" json:"count16"` + Count32 spectypes.SqlInt32 `bun:"count32" json:"count32"` +} + +// TestMapToStruct_AllSpectypes verifies that MapToStruct can convert all spectypes correctly +func TestMapToStruct_AllSpectypes(t *testing.T) { + testUUID := uuid.New() + testTime := time.Now() + + tests := []struct { + name string + dataMap map[string]interface{} + validator func(*testing.T, *TestModel) + }{ + { + name: "SqlString from string", + dataMap: map[string]interface{}{ + "name": "John Doe", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Name.Valid || m.Name.String() != "John Doe" { + t.Errorf("expected name='John Doe', got valid=%v, value=%s", m.Name.Valid, m.Name.String()) + } + }, + }, + { + name: "SqlInt64 from int64", + dataMap: map[string]interface{}{ + "age": int64(42), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Age.Valid || m.Age.Int64() != 42 { + t.Errorf("expected age=42, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64()) + } + }, + }, + { + name: "SqlInt64 from string", + dataMap: map[string]interface{}{ + "age": "99", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Age.Valid || m.Age.Int64() != 99 { + t.Errorf("expected age=99, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64()) + } + }, + }, + { + name: "SqlFloat64 from float64", + dataMap: map[string]interface{}{ + "score": float64(98.5), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Score.Valid || m.Score.Float64() != 98.5 { + t.Errorf("expected score=98.5, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64()) + } + }, + }, + { + name: "SqlBool from bool", + dataMap: map[string]interface{}{ + "active": true, + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Active.Valid || !m.Active.Bool() { + t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool()) + } + }, + }, + { + name: "SqlUUID from string", + dataMap: map[string]interface{}{ + "uuid": testUUID.String(), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.UUID.Valid || m.UUID.UUID() != testUUID { + t.Errorf("expected uuid=%s, got valid=%v, value=%s", testUUID.String(), m.UUID.Valid, m.UUID.UUID().String()) + } + }, + }, + { + name: "SqlTimeStamp from time.Time", + dataMap: map[string]interface{}{ + "created_at": testTime, + }, + validator: func(t *testing.T, m *TestModel) { + if !m.CreatedAt.Valid { + t.Errorf("expected created_at to be valid") + } + // Check if times are close enough (within a second) + diff := m.CreatedAt.Time().Sub(testTime) + if diff < -time.Second || diff > time.Second { + t.Errorf("time difference too large: %v", diff) + } + }, + }, + { + name: "SqlTimeStamp from string", + dataMap: map[string]interface{}{ + "created_at": "2024-01-15T10:30:00", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.CreatedAt.Valid { + t.Errorf("expected created_at to be valid") + } + expected := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + if m.CreatedAt.Time().Year() != expected.Year() || + m.CreatedAt.Time().Month() != expected.Month() || + m.CreatedAt.Time().Day() != expected.Day() { + t.Errorf("expected date 2024-01-15, got %v", m.CreatedAt.Time()) + } + }, + }, + { + name: "SqlDate from string", + dataMap: map[string]interface{}{ + "birth_date": "2000-05-20", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.BirthDate.Valid { + t.Errorf("expected birth_date to be valid") + } + expected := "2000-05-20" + if m.BirthDate.String() != expected { + t.Errorf("expected date=%s, got %s", expected, m.BirthDate.String()) + } + }, + }, + { + name: "SqlTime from string", + dataMap: map[string]interface{}{ + "start_time": "14:30:00", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.StartTime.Valid { + t.Errorf("expected start_time to be valid") + } + if m.StartTime.String() != "14:30:00" { + t.Errorf("expected time=14:30:00, got %s", m.StartTime.String()) + } + }, + }, + { + name: "SqlJSONB from map", + dataMap: map[string]interface{}{ + "metadata": map[string]interface{}{ + "key1": "value1", + "key2": 123, + }, + }, + validator: func(t *testing.T, m *TestModel) { + if len(m.Metadata) == 0 { + t.Errorf("expected metadata to have data") + } + asMap, err := m.Metadata.AsMap() + if err != nil { + t.Fatalf("failed to convert metadata to map: %v", err) + } + if asMap["key1"] != "value1" { + t.Errorf("expected key1=value1, got %v", asMap["key1"]) + } + }, + }, + { + name: "SqlJSONB from string", + dataMap: map[string]interface{}{ + "metadata": `{"test":"data"}`, + }, + validator: func(t *testing.T, m *TestModel) { + if len(m.Metadata) == 0 { + t.Errorf("expected metadata to have data") + } + asMap, err := m.Metadata.AsMap() + if err != nil { + t.Fatalf("failed to convert metadata to map: %v", err) + } + if asMap["test"] != "data" { + t.Errorf("expected test=data, got %v", asMap["test"]) + } + }, + }, + { + name: "SqlJSONB from []byte", + dataMap: map[string]interface{}{ + "metadata": []byte(`{"byte":"array"}`), + }, + validator: func(t *testing.T, m *TestModel) { + if len(m.Metadata) == 0 { + t.Errorf("expected metadata to have data") + } + if string(m.Metadata) != `{"byte":"array"}` { + t.Errorf("expected {\"byte\":\"array\"}, got %s", string(m.Metadata)) + } + }, + }, + { + name: "SqlInt16 from int16", + dataMap: map[string]interface{}{ + "count16": int16(100), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Count16.Valid || m.Count16.Int64() != 100 { + t.Errorf("expected count16=100, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64()) + } + }, + }, + { + name: "SqlInt32 from int32", + dataMap: map[string]interface{}{ + "count32": int32(5000), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Count32.Valid || m.Count32.Int64() != 5000 { + t.Errorf("expected count32=5000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64()) + } + }, + }, + { + name: "nil values create invalid nulls", + dataMap: map[string]interface{}{ + "name": nil, + "age": nil, + "active": nil, + "created_at": nil, + }, + validator: func(t *testing.T, m *TestModel) { + if m.Name.Valid { + t.Error("expected name to be invalid for nil value") + } + if m.Age.Valid { + t.Error("expected age to be invalid for nil value") + } + if m.Active.Valid { + t.Error("expected active to be invalid for nil value") + } + if m.CreatedAt.Valid { + t.Error("expected created_at to be invalid for nil value") + } + }, + }, + { + name: "all types together", + dataMap: map[string]interface{}{ + "id": int64(1), + "name": "Test User", + "age": int64(30), + "score": float64(95.7), + "active": true, + "uuid": testUUID.String(), + "created_at": "2024-01-15T10:30:00", + "birth_date": "1994-06-15", + "start_time": "09:00:00", + "metadata": map[string]interface{}{"role": "admin"}, + "count16": int16(50), + "count32": int32(1000), + }, + validator: func(t *testing.T, m *TestModel) { + if m.ID != 1 { + t.Errorf("expected id=1, got %d", m.ID) + } + if !m.Name.Valid || m.Name.String() != "Test User" { + t.Errorf("expected name='Test User', got valid=%v, value=%s", m.Name.Valid, m.Name.String()) + } + if !m.Age.Valid || m.Age.Int64() != 30 { + t.Errorf("expected age=30, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64()) + } + if !m.Score.Valid || m.Score.Float64() != 95.7 { + t.Errorf("expected score=95.7, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64()) + } + if !m.Active.Valid || !m.Active.Bool() { + t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool()) + } + if !m.UUID.Valid { + t.Error("expected uuid to be valid") + } + if !m.CreatedAt.Valid { + t.Error("expected created_at to be valid") + } + if !m.BirthDate.Valid || m.BirthDate.String() != "1994-06-15" { + t.Errorf("expected birth_date=1994-06-15, got valid=%v, value=%s", m.BirthDate.Valid, m.BirthDate.String()) + } + if !m.StartTime.Valid || m.StartTime.String() != "09:00:00" { + t.Errorf("expected start_time=09:00:00, got valid=%v, value=%s", m.StartTime.Valid, m.StartTime.String()) + } + if len(m.Metadata) == 0 { + t.Error("expected metadata to have data") + } + if !m.Count16.Valid || m.Count16.Int64() != 50 { + t.Errorf("expected count16=50, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64()) + } + if !m.Count32.Valid || m.Count32.Int64() != 1000 { + t.Errorf("expected count32=1000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64()) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model := &TestModel{} + if err := MapToStruct(tt.dataMap, model); err != nil { + t.Fatalf("MapToStruct failed: %v", err) + } + tt.validator(t, model) + }) + } +} + +// TestMapToStruct_PartialUpdate tests that partial updates preserve unset fields +func TestMapToStruct_PartialUpdate(t *testing.T) { + // Create initial model with some values + initial := &TestModel{ + ID: 1, + Name: spectypes.NewSqlString("Original Name"), + Age: spectypes.NewSqlInt64(25), + } + + // Update only the age field + partialUpdate := map[string]interface{}{ + "age": int64(30), + } + + // Apply partial update + if err := MapToStruct(partialUpdate, initial); err != nil { + t.Fatalf("MapToStruct failed: %v", err) + } + + // Verify age was updated + if !initial.Age.Valid || initial.Age.Int64() != 30 { + t.Errorf("expected age=30, got valid=%v, value=%d", initial.Age.Valid, initial.Age.Int64()) + } + + // Verify name was preserved (not overwritten with zero value) + if !initial.Name.Valid || initial.Name.String() != "Original Name" { + t.Errorf("expected name='Original Name' to be preserved, got valid=%v, value=%s", initial.Name.Valid, initial.Name.String()) + } + + // Verify ID was preserved + if initial.ID != 1 { + t.Errorf("expected id=1 to be preserved, got %d", initial.ID) + } +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index bf082e7..a8d32d0 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -701,97 +701,130 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url // Get the primary key name pkName := reflection.GetPrimaryKeyName(model) - // First, read the existing record from the database - existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := h.db.NewSelect().Model(existingRecord) + // Wrap in transaction to ensure BeforeUpdate hook is inside transaction + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // First, read the existing record from the database + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Column("*") - // Apply conditions to select - if urlID != "" { - logger.Debug("Updating by URL ID: %s", urlID) - selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) - } else if reqID != nil { - switch id := reqID.(type) { - case string: - logger.Debug("Updating by request ID: %s", id) - selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) - case []string: - if len(id) > 0 { - logger.Debug("Updating by multiple IDs: %v", id) - selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + // Apply conditions to select + if urlID != "" { + logger.Debug("Updating by URL ID: %s", urlID) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + logger.Debug("Updating by request ID: %s", id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + if len(id) > 0 { + logger.Debug("Updating by multiple IDs: %v", id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } } } - } - if err := selectQuery.ScanModel(ctx); err != nil { - if err == sql.ErrNoRows { - logger.Warn("No records found to update") - h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) - return - } - logger.Error("Error fetching existing record: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error fetching existing record", err) - return - } - - // Convert existing record to map - existingMap := make(map[string]interface{}) - jsonData, err := json.Marshal(existingRecord) - if err != nil { - logger.Error("Error marshaling existing record: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) - return - } - if err := json.Unmarshal(jsonData, &existingMap); err != nil { - logger.Error("Error unmarshaling existing record: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) - return - } - - // Merge only non-null and non-empty values from the incoming request into the existing record - for key, newValue := range updates { - // Skip if the value is nil - if newValue == nil { - continue + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("no records found to update") + } + return fmt.Errorf("error fetching existing record: %w", err) } - // Skip if the value is an empty string - if strVal, ok := newValue.(string); ok && strVal == "" { - continue + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("error marshaling existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("error unmarshaling existing record: %w", err) } - // Update the existing map with the new value - existingMap[key] = newValue - } - - // Build update query with merged data - query := h.db.NewUpdate().Table(tableName).SetMap(existingMap) - - // Apply conditions - if urlID != "" { - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) - } else if reqID != nil { - switch id := reqID.(type) { - case string: - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) - case []string: - query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + // Execute BeforeUpdate hooks inside transaction + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Options: options, + ID: urlID, + Data: updates, + Writer: w, + Tx: tx, } - } - result, err := query.Exec(ctx) + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed: %w", err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + updates = modifiedData + } + + // Merge only non-null and non-empty values from the incoming request into the existing record + for key, newValue := range updates { + // Skip if the value is nil + if newValue == nil { + continue + } + + // Skip if the value is an empty string + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + + // Update the existing map with the new value + existingMap[key] = newValue + } + + // Build update query with merged data + query := tx.NewUpdate().Table(tableName).SetMap(existingMap) + + // Apply conditions + if urlID != "" { + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } + } + + result, err := query.Exec(ctx) + if err != nil { + return fmt.Errorf("error updating record(s): %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("no records found to update") + } + + // Execute AfterUpdate hooks inside transaction + hookCtx.Result = updates + hookCtx.Error = nil + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + return fmt.Errorf("AfterUpdate hook failed: %w", err) + } + + return nil + }) + if err != nil { logger.Error("Update error: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err) + if err.Error() == "no records found to update" { + h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err) + } else { + h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err) + } return } - if result.RowsAffected() == 0 { - logger.Warn("No records found to update") - h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) - return - } - - logger.Info("Successfully updated %d records", result.RowsAffected()) + logger.Info("Successfully updated record(s)") // Invalidate cache for this table cacheTags := buildCacheTags(schema, tableName) if err := invalidateCacheForTags(ctx, cacheTags); err != nil { @@ -849,9 +882,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemID, ok := item["id"]; ok { + itemIDStr := fmt.Sprintf("%v", itemID) + // First, read the existing record existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { continue // Skip if record not found @@ -869,6 +904,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url return fmt.Errorf("failed to unmarshal existing record: %w", err) } + // Execute BeforeUpdate hooks inside transaction + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Options: options, + ID: itemIDStr, + Data: item, + Writer: w, + Tx: tx, + } + + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + item = modifiedData + } + // Merge only non-null and non-empty values for key, newValue := range item { if newValue == nil { @@ -884,6 +942,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if _, err := txQuery.Exec(ctx); err != nil { return err } + + // Execute AfterUpdate hooks inside transaction + hookCtx.Result = item + hookCtx.Error = nil + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err) + } } } return nil @@ -957,9 +1022,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url for _, item := range updates { if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { + itemIDStr := fmt.Sprintf("%v", itemID) + // First, read the existing record existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { continue // Skip if record not found @@ -977,6 +1044,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url return fmt.Errorf("failed to unmarshal existing record: %w", err) } + // Execute BeforeUpdate hooks inside transaction + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Options: options, + ID: itemIDStr, + Data: itemMap, + Writer: w, + Tx: tx, + } + + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + itemMap = modifiedData + } + // Merge only non-null and non-empty values for key, newValue := range itemMap { if newValue == nil { @@ -992,6 +1082,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if _, err := txQuery.Exec(ctx); err != nil { return err } + + // Execute AfterUpdate hooks inside transaction + hookCtx.Result = itemMap + hookCtx.Error = nil + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err) + } + list = append(list, item) } } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 2f20ce1..5118073 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1110,30 +1110,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id logger.Info("Updating record in %s.%s", schema, entity) - // Execute BeforeUpdate hooks - hookCtx := &HookContext{ - Context: ctx, - Handler: h, - Schema: schema, - Entity: entity, - TableName: tableName, - Tx: h.db, - Model: model, - Options: options, - ID: id, - Data: data, - Writer: w, - } - - if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { - logger.Error("BeforeUpdate hook failed: %v", err) - h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } - - // Use potentially modified data from hook context - data = hookCtx.Data - // Convert data to map dataMap, ok := data.(map[string]interface{}) if !ok { @@ -1167,6 +1143,9 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Variable to store the updated record var updatedRecord interface{} + // Declare hook context to be used inside and outside transaction + var hookCtx *HookContext + // Process nested relations if present err := h.db.RunInTransaction(ctx, func(tx common.Database) error { // Create temporary nested processor with transaction @@ -1174,7 +1153,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // First, read the existing record from the database existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { return fmt.Errorf("record not found with ID: %v", targetID) @@ -1204,6 +1183,30 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id nestedRelations = relations } + // Execute BeforeUpdate hooks inside transaction + hookCtx = &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Tx: tx, + Model: model, + Options: options, + ID: id, + Data: dataMap, + Writer: w, + } + + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed: %w", err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + dataMap = modifiedData + } + // Merge only non-null and non-empty values from the incoming request into the existing record for key, newValue := range dataMap { // Skip if the value is nil