diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index b6a0b97..cd5aaeb 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -836,7 +836,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url err := h.db.RunInTransaction(ctx, func(tx common.Database) error { // First, read the existing record from the database existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Column("*") + selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...) // Apply conditions to select if urlID != "" { @@ -955,13 +955,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url return } + // Fetch the updated record after the transaction commits to capture any trigger changes + updatedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() + fetchQuery := h.db.NewSelect().Model(updatedRecord).Column(reflection.GetSQLModelColumns(model)...) + if urlID != "" { + fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + if len(id) > 0 { + fetchQuery = fetchQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } + } + } + if err := fetchQuery.ScanModel(ctx); err != nil { + logger.Error("Failed to fetch updated record: %v", err) + h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err) + return + } + logger.Info("Successfully updated record(s)") // Invalidate cache for this table cacheTags := buildCacheTags(schema, tableName) if err := invalidateCacheForTags(ctx, cacheTags); err != nil { logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) } - h.sendResponse(w, data, nil) + h.sendResponse(w, updatedRecord, nil) case []map[string]interface{}: // Batch update with array of objects @@ -1017,7 +1038,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url // First, read the existing record existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { continue // Skip if record not found @@ -1089,13 +1110,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err) return } - logger.Info("Successfully updated %d records", len(updates)) + + // Fetch updated records after the transaction commits to capture any trigger changes + fetchedUpdates := make([]interface{}, 0, len(updates)) + for _, item := range updates { + if itemID, ok := item["id"]; ok && itemID != nil { + fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() + fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + if err := fetchQuery.ScanModel(ctx); err != nil { + logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err) + h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err) + return + } + fetchedUpdates = append(fetchedUpdates, fetchedRecord) + } + } + + logger.Info("Successfully updated %d records", len(fetchedUpdates)) // Invalidate cache for this table cacheTags := buildCacheTags(schema, tableName) if err := invalidateCacheForTags(ctx, cacheTags); err != nil { logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) } - h.sendResponse(w, updates, nil) + h.sendResponse(w, fetchedUpdates, nil) case []interface{}: // Batch update with []interface{} @@ -1157,7 +1194,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url // First, read the existing record existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { continue // Skip if record not found @@ -1232,13 +1269,31 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err) return } - logger.Info("Successfully updated %d records", len(list)) + + // Fetch updated records after the transaction commits to capture any trigger changes + fetchedList := make([]interface{}, 0, len(list)) + for _, item := range list { + if itemMap, ok := item.(map[string]interface{}); ok { + if itemID, ok := itemMap["id"]; ok && itemID != nil { + fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() + fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + if err := fetchQuery.ScanModel(ctx); err != nil { + logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err) + h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err) + return + } + fetchedList = append(fetchedList, fetchedRecord) + } + } + } + + logger.Info("Successfully updated %d records", len(fetchedList)) // Invalidate cache for this table cacheTags := buildCacheTags(schema, tableName) if err := invalidateCacheForTags(ctx, cacheTags); err != nil { logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) } - h.sendResponse(w, list, nil) + h.sendResponse(w, fetchedList, nil) default: logger.Error("Invalid data type for update operation: %T", data)