diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index ecf226b..604456b 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -610,6 +610,9 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat dataSlice := h.normalizeToSlice(data) logger.Debug("Processing %d item(s) for creation", len(dataSlice)) + // Store original data maps for merging later + originalDataMaps := make([]map[string]interface{}, 0, len(dataSlice)) + // Process all items in a transaction results := make([]interface{}, 0, len(dataSlice)) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { @@ -630,6 +633,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat } } + // Store a copy of the original data map for merging later + originalMap := make(map[string]interface{}) + for k, v := range itemMap { + originalMap[k] = v + } + originalDataMaps = append(originalDataMaps, originalMap) + // Extract nested relations if present (but don't process them yet) var nestedRelations map[string]interface{} if h.shouldUseNestedProcessor(itemMap, model) { @@ -704,14 +714,26 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat return } + // Merge created records with original request data + // This preserves extra keys from the request + mergedResults := make([]interface{}, 0, len(results)) + for i, result := range results { + if i < len(originalDataMaps) { + merged := h.mergeRecordWithRequest(result, originalDataMaps[i]) + mergedResults = append(mergedResults, merged) + } else { + mergedResults = append(mergedResults, result) + } + } + // Execute AfterCreate hooks var responseData interface{} - if len(results) == 1 { - responseData = results[0] - hookCtx.Result = results[0] + if len(mergedResults) == 1 { + responseData = mergedResults[0] + hookCtx.Result = mergedResults[0] } else { - responseData = results - hookCtx.Result = map[string]interface{}{"created": len(results), "data": results} + responseData = mergedResults + hookCtx.Result = map[string]interface{}{"created": len(mergedResults), "data": mergedResults} } hookCtx.Error = nil @@ -721,7 +743,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat return } - logger.Info("Successfully created %d record(s)", len(results)) + logger.Info("Successfully created %d record(s)", len(mergedResults)) h.sendResponseWithOptions(w, responseData, nil, &options) } @@ -790,6 +812,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id return } + // Get the primary key name for the model + pkName := reflection.GetPrimaryKeyName(model) + + // Variable to store the updated record + var updatedRecord interface{} + // Process nested relations if present err := h.db.RunInTransaction(ctx, func(tx common.Database) error { // Create temporary nested processor with transaction @@ -807,9 +835,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id nestedRelations = relations } - // Get the primary key name for the model - pkName := reflection.GetPrimaryKeyName(model) - // Ensure ID is in the data map for the update dataMap[pkName] = targetID @@ -842,10 +867,18 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id } } - // Store result for hooks - hookCtx.Result = map[string]interface{}{ - "updated": result.RowsAffected(), + // Fetch the updated record to return the new values + modelValue := reflect.New(reflect.TypeOf(model)).Interface() + selectQuery := tx.NewSelect().Model(modelValue).Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + if err := selectQuery.ScanModel(ctx); err != nil { + return fmt.Errorf("failed to fetch updated record: %w", err) } + + updatedRecord = modelValue + + // Store result for hooks + hookCtx.Result = updatedRecord + _ = result // Keep result variable for potential future use return nil }) @@ -855,7 +888,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id return } + // Merge the updated record with the original request data + // This preserves extra keys from the request and updates values from the database + mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap) + // Execute AfterUpdate hooks + hookCtx.Result = mergedData hookCtx.Error = nil if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { logger.Error("AfterUpdate hook failed: %v", err) @@ -864,7 +902,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id } logger.Info("Successfully updated record with ID: %v", targetID) - h.sendResponseWithOptions(w, hookCtx.Result, nil, &options) + h.sendResponseWithOptions(w, mergedData, nil, &options) } func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) { @@ -1127,6 +1165,39 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id h.sendResponse(w, responseData, nil) } +// mergeRecordWithRequest merges a database record with the original request data +// This preserves extra keys from the request that aren't in the database model +// and updates values from the database (e.g., from SQL triggers or defaults) +func (h *Handler) mergeRecordWithRequest(dbRecord interface{}, requestData map[string]interface{}) map[string]interface{} { + // Convert the database record to a map + dbMap := make(map[string]interface{}) + + // Marshal and unmarshal to convert struct to map + jsonData, err := json.Marshal(dbRecord) + if err != nil { + logger.Warn("Failed to marshal database record for merging: %v", err) + return requestData + } + + if err := json.Unmarshal(jsonData, &dbMap); err != nil { + logger.Warn("Failed to unmarshal database record for merging: %v", err) + return requestData + } + + // Start with the request data (preserves extra keys) + result := make(map[string]interface{}) + for k, v := range requestData { + result[k] = v + } + + // Update with values from database (overwrites with DB values, including trigger changes) + for k, v := range dbMap { + result[k] = v + } + + return result +} + // normalizeToSlice converts data to a slice. Single items become a 1-item slice. func (h *Handler) normalizeToSlice(data interface{}) []interface{} { if data == nil { @@ -1663,22 +1734,22 @@ func (h *Handler) cleanJSON(data interface{}) interface{} { } func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) { - var details string + var errorMsg string if err != nil { - details = err.Error() + errorMsg = err.Error() + } else if message != "" { + errorMsg = message + } else { + errorMsg = code } - response := common.Response{ - Success: false, - Error: &common.APIError{ - Code: code, - Message: message, - Details: details, - }, + response := map[string]interface{}{ + "_error": errorMsg, + "_retval": 1, } w.WriteHeader(statusCode) - if err := w.WriteJSON(response); err != nil { - logger.Error("Failed to write JSON error response: %v", err) + if jsonErr := w.WriteJSON(response); jsonErr != nil { + logger.Error("Failed to write JSON error response: %v", jsonErr) } }