From c0c669bd3d425bf4f29d15676a418a68a0e5c417 Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 5 Jan 2026 12:31:01 +0200 Subject: [PATCH] feat(handler): enhance update logic to merge existing records with incoming data --- pkg/funcspec/function_api.go | 140 +++++++++++++++++----------------- pkg/funcspec/hooks.go | 5 ++ pkg/resolvespec/handler.go | 141 +++++++++++++++++++++++++++++++++-- pkg/restheadspec/handler.go | 41 +++++++++- 4 files changed, 248 insertions(+), 79 deletions(-) diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index cf8787c..6d6328d 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -123,27 +123,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun ComplexAPI: complexAPI, } - // Execute BeforeQueryList hook - if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil { - logger.Error("BeforeQueryList hook failed: %v", err) - sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } - - // Check if hook aborted the operation - if hookCtx.Abort { - if hookCtx.AbortCode == 0 { - hookCtx.AbortCode = http.StatusBadRequest - } - sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) - return - } - - // Use potentially modified SQL query and variables from hooks - sqlquery = hookCtx.SQLQuery - variables = hookCtx.Variables - // complexAPI = hookCtx.ComplexAPI - // Extract input variables from SQL query (placeholders like [variable]) sqlquery = h.extractInputVariables(sqlquery, &inputvars) @@ -203,6 +182,27 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Set transaction in hook context for hooks to use + hookCtx.Tx = tx + + // Execute BeforeQueryList hook (inside transaction) + if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil { + logger.Error("BeforeQueryList hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + + // Check if hook aborted the operation + if hookCtx.Abort { + if hookCtx.AbortCode == 0 { + hookCtx.AbortCode = http.StatusBadRequest + } + sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) + return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage) + } + + // Use potentially modified SQL query from hook + sqlquery = hookCtx.SQLQuery sqlqueryCnt := sqlquery // Parse sorting and pagination parameters @@ -286,6 +286,21 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun } total = hookCtx.Total + // Execute AfterQueryList hook (inside transaction) + hookCtx.Result = dbobjlist + hookCtx.Total = total + hookCtx.Error = nil + if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil { + logger.Error("AfterQueryList hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { + dbobjlist = modifiedResult + } + total = hookCtx.Total + return nil }) @@ -294,21 +309,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun return } - // Execute AfterQueryList hook - hookCtx.Result = dbobjlist - hookCtx.Total = total - hookCtx.Error = err - if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil { - logger.Error("AfterQueryList hook failed: %v", err) - sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) - return - } - // Use potentially modified result from hook - if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { - dbobjlist = modifiedResult - } - total = hookCtx.Total - // Set response headers respOffset := 0 if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { @@ -459,26 +459,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp ComplexAPI: complexAPI, } - // Execute BeforeQuery hook - if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil { - logger.Error("BeforeQuery hook failed: %v", err) - sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } - - // Check if hook aborted the operation - if hookCtx.Abort { - if hookCtx.AbortCode == 0 { - hookCtx.AbortCode = http.StatusBadRequest - } - sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) - return - } - - // Use potentially modified SQL query and variables from hooks - sqlquery = hookCtx.SQLQuery - variables = hookCtx.Variables - // Extract input variables from SQL query sqlquery = h.extractInputVariables(sqlquery, &inputvars) @@ -554,6 +534,28 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Set transaction in hook context for hooks to use + hookCtx.Tx = tx + + // Execute BeforeQuery hook (inside transaction) + if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil { + logger.Error("BeforeQuery hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + + // Check if hook aborted the operation + if hookCtx.Abort { + if hookCtx.AbortCode == 0 { + hookCtx.AbortCode = http.StatusBadRequest + } + sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) + return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage) + } + + // Use potentially modified SQL query from hook + sqlquery = hookCtx.SQLQuery + // Execute BeforeSQLExec hook if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil { logger.Error("BeforeSQLExec hook failed: %v", err) @@ -586,6 +588,19 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp dbobj = modifiedResult } + // Execute AfterQuery hook (inside transaction) + hookCtx.Result = dbobj + hookCtx.Error = nil + if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil { + logger.Error("AfterQuery hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { + dbobj = modifiedResult + } + return nil }) @@ -594,19 +609,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp return } - // Execute AfterQuery hook - hookCtx.Result = dbobj - hookCtx.Error = err - if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil { - logger.Error("AfterQuery hook failed: %v", err) - sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) - return - } - // Use potentially modified result from hook - if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { - dbobj = modifiedResult - } - // Execute BeforeResponse hook hookCtx.Result = dbobj if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil { diff --git a/pkg/funcspec/hooks.go b/pkg/funcspec/hooks.go index 26d46a3..d04d19e 100644 --- a/pkg/funcspec/hooks.go +++ b/pkg/funcspec/hooks.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/security" ) @@ -46,6 +47,10 @@ type HookContext struct { // User context UserContext *security.UserContext + // Tx provides access to the database/transaction for executing additional SQL + // This allows hooks to run custom queries in addition to the main Query chain + Tx common.Database + // Pagination and filtering (for list queries) SortColumns string Limit int diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index c65f26d..d57826c 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -698,20 +698,83 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url } // Standard processing without nested relations - query := h.db.NewUpdate().Table(tableName).SetMap(updates) + // Get the primary key name + pkName := reflection.GetPrimaryKeyName(model) - // Apply conditions + // First, read the existing record from the database + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := h.db.NewSelect().Model(existingRecord) + + // Apply conditions to select if urlID != "" { logger.Debug("Updating by URL ID: %s", urlID) - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) } else if reqID != nil { switch id := reqID.(type) { case string: logger.Debug("Updating by request ID: %s", id) - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) case []string: - logger.Debug("Updating by multiple IDs: %v", id) - query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id) + if len(id) > 0 { + logger.Debug("Updating by multiple IDs: %v", id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } + } + } + + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + logger.Warn("No records found to update") + h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) + return + } + logger.Error("Error fetching existing record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error fetching existing record", err) + return + } + + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + logger.Error("Error marshaling existing record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) + return + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + logger.Error("Error unmarshaling existing record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) + return + } + + // Merge only non-null and non-empty values from the incoming request into the existing record + for key, newValue := range updates { + // Skip if the value is nil + if newValue == nil { + continue + } + + // Skip if the value is an empty string + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + + // Update the existing map with the new value + existingMap[key] = newValue + } + + // Build update query with merged data + query := h.db.NewUpdate().Table(tableName).SetMap(existingMap) + + // Apply conditions + if urlID != "" { + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) } } @@ -782,11 +845,42 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url } // Standard batch update without nested relations + pkName := reflection.GetPrimaryKeyName(model) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemID, ok := item["id"]; ok { + // First, read the existing record + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + continue // Skip if record not found + } + return fmt.Errorf("failed to fetch existing record: %w", err) + } - txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID) + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("failed to marshal existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("failed to unmarshal existing record: %w", err) + } + + // Merge only non-null and non-empty values + for key, newValue := range item { + if newValue == nil { + continue + } + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + existingMap[key] = newValue + } + + txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if _, err := txQuery.Exec(ctx); err != nil { return err } @@ -857,13 +951,44 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url } // Standard batch update without nested relations + pkName := reflection.GetPrimaryKeyName(model) list := make([]interface{}, 0) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { + // First, read the existing record + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + continue // Skip if record not found + } + return fmt.Errorf("failed to fetch existing record: %w", err) + } - txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID) + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("failed to marshal existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("failed to unmarshal existing record: %w", err) + } + + // Merge only non-null and non-empty values + for key, newValue := range itemMap { + if newValue == nil { + continue + } + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + existingMap[key] = newValue + } + + txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if _, err := txQuery.Exec(ctx); err != nil { return err } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5a71a20..4d567ae 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1239,6 +1239,26 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Create temporary nested processor with transaction txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) + // First, read the existing record from the database + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("record not found with ID: %v", targetID) + } + return fmt.Errorf("failed to fetch existing record: %w", err) + } + + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("failed to marshal existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("failed to unmarshal existing record: %w", err) + } + // Extract nested relations if present (but don't process them yet) var nestedRelations map[string]interface{} if h.shouldUseNestedProcessor(dataMap, model) { @@ -1251,8 +1271,25 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id nestedRelations = relations } + // Merge only non-null and non-empty values from the incoming request into the existing record + for key, newValue := range dataMap { + // Skip if the value is nil + if newValue == nil { + continue + } + + // Skip if the value is an empty string + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + + // Update the existing map with the new value + existingMap[key] = newValue + } + // Ensure ID is in the data map for the update - dataMap[pkName] = targetID + existingMap[pkName] = targetID + dataMap = existingMap // Populate model instance from dataMap to preserve custom types (like SqlJSONB) // Get the type of the model, handling both pointer and non-pointer types @@ -1297,7 +1334,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Fetch the updated record to return the new values modelValue := reflect.New(reflect.TypeOf(model)).Interface() - selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { return fmt.Errorf("failed to fetch updated record: %w", err) }