diff --git a/pkg/mqttspec/handler.go b/pkg/mqttspec/handler.go index 48c3ed4..7cdf7ef 100644 --- a/pkg/mqttspec/handler.go +++ b/pkg/mqttspec/handler.go @@ -781,6 +781,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) { return nil, fmt.Errorf("failed to create record: %w", err) } + // Re-fetch the created record to capture DB-generated defaults/triggers. + if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil { + hookCtx.ID = fmt.Sprintf("%v", pkVal) + return h.readByID(hookCtx) + } + return hookCtx.ModelPtr, nil } diff --git a/pkg/resolvemcp/handler.go b/pkg/resolvemcp/handler.go index 5d3f53b..f2ccfdf 100644 --- a/pkg/resolvemcp/handler.go +++ b/pkg/resolvemcp/handler.go @@ -436,24 +436,27 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data for key, value := range v { query = query.Value(key, value) } - if _, err := query.Exec(ctx); err != nil { - return nil, fmt.Errorf("create error: %w", err) - } - // Re-fetch after insert to capture DB-generated defaults/triggers. - if pkVal, ok := v[pkName]; ok && pkVal != nil { + if pkName != "" { + var insertedID interface{} + if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil { + return nil, fmt.Errorf("create error: %w", err) + } + // Re-fetch after insert to capture DB-generated defaults/triggers. modelType := reflect.TypeOf(model) if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } fetchedRecord := reflect.New(modelType).Interface() if err := h.db.NewSelect().Model(fetchedRecord). - Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID). ScanModel(ctx); err == nil { - jsonData, _ := json.Marshal(fetchedRecord) - var fetchedMap map[string]interface{} - if json.Unmarshal(jsonData, &fetchedMap) == nil { - v = fetchedMap - } + v = mergeWithInput(fetchedRecord, v) + } else { + logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, err) + } + } else { + if _, err := query.Exec(ctx); err != nil { + return nil, fmt.Errorf("create error: %w", err) } } hookCtx.Result = v @@ -463,7 +466,12 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data return v, nil case []interface{}: - results := make([]interface{}, 0, len(v)) + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + originals := make([]map[string]interface{}, 0, len(v)) + insertedIDs := make([]interface{}, 0, len(v)) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { itemMap, ok := item.(map[string]interface{}) @@ -474,16 +482,43 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data for key, value := range itemMap { q = q.Value(key, value) } - if _, err := q.Exec(ctx); err != nil { + if pkName == "" { + if _, err := q.Exec(ctx); err != nil { + return err + } + originals = append(originals, itemMap) + insertedIDs = append(insertedIDs, nil) + continue + } + var returnedID interface{} + if err := q.Returning(pkName).Scan(ctx, &returnedID); err != nil { return err } - results = append(results, item) + originals = append(originals, itemMap) + insertedIDs = append(insertedIDs, returnedID) } return nil }) if err != nil { return nil, fmt.Errorf("batch create error: %w", err) } + // Re-fetch each record after transaction commits; fall back to input on failure. + results := make([]interface{}, 0, len(insertedIDs)) + for i, pkVal := range insertedIDs { + if pkVal == nil { + results = append(results, originals[i]) + continue + } + fetchedRecord := reflect.New(modelType).Interface() + if err := h.db.NewSelect().Model(fetchedRecord). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal). + ScanModel(ctx); err == nil { + results = append(results, mergeWithInput(fetchedRecord, originals[i])) + } else { + logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, err) + results = append(results, originals[i]) + } + } hookCtx.Result = results if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { return nil, fmt.Errorf("AfterCreate hook failed: %w", err) @@ -787,6 +822,28 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition st return "", nil } +// mergeWithInput merges a database record with the original request data. +// DB values take precedence (capturing triggers/defaults), while extra +// input keys that have no DB column are preserved in the response. +func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}, len(input)) + for k, v := range input { + result[k] = v + } + jsonData, err := json.Marshal(dbRecord) + if err != nil { + return result + } + var dbMap map[string]interface{} + if err := json.Unmarshal(jsonData, &dbMap); err != nil { + return result + } + for k, v := range dbMap { + result[k] = v + } + return result +} + func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { for i := range preloads { preload := &preloads[i] diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index cd5aaeb..fd1c29b 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -602,23 +602,44 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat } // Standard processing without nested relations + pkName := reflection.GetPrimaryKeyName(model) query := h.db.NewInsert().Table(tableName) for key, value := range v { query = query.Value(key, common.ConvertSliceForBun(value)) } - result, err := query.Exec(ctx) - if err != nil { - logger.Error("Error creating record: %v", err) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) - return + var responseData interface{} = v + if pkName == "" { + // No PK on model — insert and return input as-is. + result, err := query.Exec(ctx) + if err != nil { + logger.Error("Error creating record: %v", err) + h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) + return + } + logger.Info("Successfully created record, rows affected: %d", result.RowsAffected()) + } else { + var insertedID interface{} + if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil { + logger.Error("Error creating record: %v", err) + h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) + return + } + logger.Info("Successfully created record with %s: %v", pkName, insertedID) + fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() + if fetchErr := h.db.NewSelect().Model(fetchedRecord). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID). + ScanModel(ctx); fetchErr == nil { + responseData = mergeWithInput(fetchedRecord, v) + } else { + logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, fetchErr) + } } - logger.Info("Successfully created record, rows affected: %d", result.RowsAffected()) // Invalidate cache for this table cacheTags := buildCacheTags(schema, tableName) if err := invalidateCacheForTags(ctx, cacheTags); err != nil { logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) } - h.sendResponse(w, v, nil) + h.sendResponse(w, responseData, nil) case []map[string]interface{}: // Check if any item needs nested processing @@ -666,15 +687,30 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat } // Standard batch insert without nested relations + pkName := reflection.GetPrimaryKeyName(model) + modelElemType := reflection.GetPointerElement(reflect.TypeOf(model)) + originals := make([]map[string]interface{}, 0, len(v)) + insertedIDs := make([]interface{}, 0, len(v)) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { txQuery := tx.NewInsert().Table(tableName) for key, value := range item { txQuery = txQuery.Value(key, common.ConvertSliceForBun(value)) } - if _, err := txQuery.Exec(ctx); err != nil { + if pkName == "" { + if _, err := txQuery.Exec(ctx); err != nil { + return err + } + originals = append(originals, item) + insertedIDs = append(insertedIDs, nil) + continue + } + var returnedID interface{} + if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil { return err } + originals = append(originals, item) + insertedIDs = append(insertedIDs, returnedID) } return nil }) @@ -689,7 +725,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat if err := invalidateCacheForTags(ctx, cacheTags); err != nil { logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) } - h.sendResponse(w, v, nil) + // Re-fetch each record after transaction commits; fall back to input on failure. + responseItems := make([]interface{}, 0, len(insertedIDs)) + for i, pkVal := range insertedIDs { + if pkVal == nil { + responseItems = append(responseItems, originals[i]) + continue + } + fetchedRecord := reflect.New(modelElemType).Interface() + if fetchErr := h.db.NewSelect().Model(fetchedRecord). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal). + ScanModel(ctx); fetchErr == nil { + responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i])) + } else { + logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr) + responseItems = append(responseItems, originals[i]) + } + } + h.sendResponse(w, responseItems, nil) case []interface{}: // Handle []interface{} type from JSON unmarshaling @@ -742,19 +795,34 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat } // Standard batch insert without nested relations - list := make([]interface{}, 0) + pkName := reflection.GetPrimaryKeyName(model) + modelElemType := reflection.GetPointerElement(reflect.TypeOf(model)) + originals := make([]map[string]interface{}, 0, len(v)) + insertedIDs := make([]interface{}, 0, len(v)) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { - if itemMap, ok := item.(map[string]interface{}); ok { - txQuery := tx.NewInsert().Table(tableName) - for key, value := range itemMap { - txQuery = txQuery.Value(key, common.ConvertSliceForBun(value)) - } + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + txQuery := tx.NewInsert().Table(tableName) + for key, value := range itemMap { + txQuery = txQuery.Value(key, common.ConvertSliceForBun(value)) + } + if pkName == "" { if _, err := txQuery.Exec(ctx); err != nil { return err } - list = append(list, item) + originals = append(originals, itemMap) + insertedIDs = append(insertedIDs, nil) + continue } + var returnedID interface{} + if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil { + return err + } + originals = append(originals, itemMap) + insertedIDs = append(insertedIDs, returnedID) } return nil }) @@ -769,7 +837,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat if err := invalidateCacheForTags(ctx, cacheTags); err != nil { logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) } - h.sendResponse(w, list, nil) + // Re-fetch each record after transaction commits; fall back to input on failure. + responseItems := make([]interface{}, 0, len(insertedIDs)) + for i, pkVal := range insertedIDs { + if pkVal == nil { + responseItems = append(responseItems, originals[i]) + continue + } + fetchedRecord := reflect.New(modelElemType).Interface() + if fetchErr := h.db.NewSelect().Model(fetchedRecord). + Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal). + ScanModel(ctx); fetchErr == nil { + responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i])) + } else { + logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr) + responseItems = append(responseItems, originals[i]) + } + } + h.sendResponse(w, responseItems, nil) default: logger.Error("Invalid data type for create operation: %T", data) @@ -2122,3 +2207,25 @@ func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) { func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) { h.openAPIGenerator = generator } + +// mergeWithInput merges a database record with the original request data. +// DB values take precedence (capturing triggers/defaults), while extra +// input keys that have no DB column are preserved in the response. +func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}, len(input)) + for k, v := range input { + result[k] = v + } + jsonData, err := json.Marshal(dbRecord) + if err != nil { + return result + } + var dbMap map[string]interface{} + if err := json.Unmarshal(jsonData, &dbMap); err != nil { + return result + } + for k, v := range dbMap { + result[k] = v + } + return result +}