mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-06-13 09:03:44 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a2799fa224 | |||
| 1419542650 | |||
| c120b49529 | |||
| 66348dac97 | |||
| a87cd18b1b | |||
| 29449c93d5 | |||
| 3b6e5c75be | |||
| 549ccb8468 |
@@ -174,7 +174,9 @@ func (h *HTTPResponseWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
|
||||
h.SetHeader("Content-Type", "application/json")
|
||||
return json.NewEncoder(h.resp).Encode(data)
|
||||
enc := json.NewEncoder(h.resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(data)
|
||||
}
|
||||
|
||||
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
|
||||
|
||||
@@ -178,7 +178,9 @@ func (s *StandardResponseWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
||||
s.SetHeader("Content-Type", "application/json")
|
||||
return json.NewEncoder(s.w).Encode(data)
|
||||
enc := json.NewEncoder(s.w)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(data)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||
|
||||
@@ -113,7 +113,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
|
||||
// Process based on operation
|
||||
switch strings.ToLower(operation) {
|
||||
case "insert", "create":
|
||||
case "insert", "create", "add":
|
||||
// Only perform insert if we have data to insert
|
||||
if hasData {
|
||||
id, err := p.processInsert(ctx, regularData, tableName)
|
||||
@@ -141,7 +141,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
||||
}
|
||||
|
||||
case "update", "change":
|
||||
case "update", "change", "modify":
|
||||
// Only perform update if we have data to update
|
||||
if reflection.IsEmptyValue(data[pkName]) {
|
||||
logger.Warn("Skipping update for %s - no primary key", tableName)
|
||||
@@ -174,7 +174,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
result.ID = data[pkName]
|
||||
}
|
||||
|
||||
case "delete":
|
||||
case "delete", "remove":
|
||||
if reflection.IsEmptyValue(data[pkName]) {
|
||||
logger.Warn("Skipping delete for %s - no primary key", tableName)
|
||||
return result, nil
|
||||
@@ -471,13 +471,17 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
// Priority: Use foreign key field name if specified
|
||||
var foreignKeyFieldName string
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Get the JSON name for the foreign key field in the child model
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||
if foreignKeyFieldName == "" {
|
||||
// Fallback to lowercase field name
|
||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||
// For has-many/has-one: join:parentCol=childCol
|
||||
// ForeignKey = parent side, References = child side (where we actually set the value)
|
||||
childField := relInfo.ForeignKey
|
||||
if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
|
||||
childField = relInfo.References
|
||||
}
|
||||
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey)
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
|
||||
if foreignKeyFieldName == "" {
|
||||
foreignKeyFieldName = strings.ToLower(childField)
|
||||
}
|
||||
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s -> child %s)", foreignKeyFieldName, relInfo.ForeignKey, childField)
|
||||
}
|
||||
|
||||
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||
|
||||
@@ -713,6 +713,220 @@ func TestInjectForeignKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Models for asymmetric join column tests (mirrors the bun has-many join:parentCol=childCol pattern).
|
||||
// ActionOption has-many ActionOptionLinks via join:rid_actionoption=rid_actionoption_child.
|
||||
// The child column ("rid_actionoption_child") differs from the parent column ("rid_actionoption").
|
||||
type ActionOption struct {
|
||||
RidActionoption int64 `json:"rid_actionoption" bun:"rid_actionoption,pk"`
|
||||
Label string `json:"label"`
|
||||
Links []*ActionOptionLink `json:"aol_rid_actionoption_child,omitempty"`
|
||||
}
|
||||
|
||||
func (a ActionOption) TableName() string { return "action_options" }
|
||||
func (a ActionOption) GetIDName() string { return "RidActionoption" }
|
||||
|
||||
type ActionOptionLink struct {
|
||||
RidActionoptionlink int64 `json:"rid_actionoptionlink" bun:"rid_actionoptionlink,pk"`
|
||||
RidActionoptionChild int64 `json:"rid_actionoption_child" bun:"rid_actionoption_child"`
|
||||
Label string `json:"label"`
|
||||
// Note: no field named "rid_actionoption" — that is the parent's column.
|
||||
}
|
||||
|
||||
func (a ActionOptionLink) TableName() string { return "action_option_links" }
|
||||
func (a ActionOptionLink) GetIDName() string { return "RidActionoptionlink" }
|
||||
|
||||
// TestProcessNestedCUD_AsymmetricJoinColumns verifies that for a has-many relation with
|
||||
// join:parentCol=childCol, the child rows are stamped with the child-side column (References),
|
||||
// not the parent-side column (ForeignKey).
|
||||
func TestProcessNestedCUD_AsymmetricJoinColumns(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Mirrors: bun:"rel:has-many,join:rid_actionoption=rid_actionoption_child"
|
||||
relProvider.RegisterRelation("ActionOption", "aol_rid_actionoption_child", &RelationshipInfo{
|
||||
FieldName: "Links",
|
||||
JSONName: "aol_rid_actionoption_child",
|
||||
RelationType: "hasMany",
|
||||
ForeignKey: "rid_actionoption", // parent-side column (left of join:)
|
||||
References: "rid_actionoption_child", // child-side column (right of join:)
|
||||
RelatedModel: ActionOptionLink{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"label": "option-a",
|
||||
"aol_rid_actionoption_child": []interface{}{
|
||||
map[string]interface{}{"label": "link-1"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
ActionOption{},
|
||||
nil,
|
||||
"action_options",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if len(db.insertCalls) < 2 {
|
||||
t.Fatalf("Expected at least 2 insert calls (parent + child), got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
childInsert := db.insertCalls[1]
|
||||
|
||||
// The fix: child must receive "rid_actionoption_child", NOT "rid_actionoption".
|
||||
if childInsert["rid_actionoption_child"] == nil {
|
||||
t.Error("Expected child to have rid_actionoption_child set (child-side FK column)")
|
||||
}
|
||||
if childInsert["rid_actionoption"] != nil {
|
||||
t.Errorf("Child must not receive parent-side column rid_actionoption, got %v", childInsert["rid_actionoption"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessNestedCUD_BelongsToUnchanged verifies that the fix does not regress belongsTo
|
||||
// relations, where ForeignKey is already the local (child) column.
|
||||
func TestProcessNestedCUD_BelongsToUnchanged(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// For belongsTo, ForeignKey is the column on the child; References is on the parent.
|
||||
// The old and new code must behave identically here.
|
||||
relProvider.RegisterRelation("Employee", "department", &RelationshipInfo{
|
||||
FieldName: "Department",
|
||||
JSONName: "department",
|
||||
RelationType: "belongsTo",
|
||||
ForeignKey: "DepartmentID", // child's own column
|
||||
References: "ID", // parent's PK
|
||||
RelatedModel: Department{},
|
||||
})
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{"name": "Alice"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if len(db.insertCalls) < 2 {
|
||||
t.Fatalf("Expected at least 2 inserts, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Employees relation uses has_many (old-style) so it goes through the parentIDs injection path,
|
||||
// not the foreignKeyFieldName path. Just confirm no panic and employee is inserted.
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_AddAlias(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"_request": "add",
|
||||
"name": "New Department",
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(context.Background(), "insert", data, Department{}, nil, "departments")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD with _request=add failed: %v", err)
|
||||
}
|
||||
if result.ID == nil {
|
||||
t.Error("Expected result.ID to be set after add")
|
||||
}
|
||||
if len(db.insertCalls) != 1 {
|
||||
t.Errorf("Expected 1 insert call, got %d", len(db.insertCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_RemoveAlias(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"_request": "remove",
|
||||
"ID": int64(42),
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(context.Background(), "delete", data, Department{}, nil, "departments")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD with _request=remove failed: %v", err)
|
||||
}
|
||||
if len(db.deleteCalls) != 1 {
|
||||
t.Errorf("Expected 1 delete call, got %d", len(db.deleteCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_NestedAddRemoveAliases(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ID": int64(1),
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{"_request": "add", "name": "Alice"},
|
||||
map[string]interface{}{"_request": "remove", "ID": int64(5)},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(context.Background(), "update", data, Department{}, nil, "departments")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD with nested add/remove failed: %v", err)
|
||||
}
|
||||
if len(db.insertCalls) != 1 {
|
||||
t.Errorf("Expected 1 insert (add alias) for employee, got %d", len(db.insertCalls))
|
||||
}
|
||||
if len(db.deleteCalls) != 1 {
|
||||
t.Errorf("Expected 1 delete (remove alias) for employee, got %d", len(db.deleteCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyName(t *testing.T) {
|
||||
dept := Department{}
|
||||
pkName := reflection.GetPrimaryKeyName(dept)
|
||||
|
||||
@@ -614,6 +614,15 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// If the left side is a parenthesized subquery (starts with '(' and contains SQL keywords),
|
||||
// don't attempt prefix extraction from inside it.
|
||||
if len(columnRef) > 0 && columnRef[0] == '(' {
|
||||
lowerRef := strings.ToLower(columnRef)
|
||||
if strings.Contains(lowerRef, "select ") || strings.Contains(lowerRef, " from ") || strings.Contains(lowerRef, " where ") {
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there's a function call (contains opening parenthesis)
|
||||
openParenIdx := strings.Index(columnRef, "(")
|
||||
|
||||
|
||||
@@ -781,6 +781,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
// Re-fetch the created record to capture DB-generated defaults/triggers.
|
||||
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
|
||||
hookCtx.ID = fmt.Sprintf("%v", pkVal)
|
||||
return h.readByID(hookCtx)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
|
||||
+100
-5
@@ -428,14 +428,36 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
// Use potentially modified data
|
||||
data = hookCtx.Data
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
if pkName != "" {
|
||||
var insertedID interface{}
|
||||
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
// Re-fetch after insert to capture DB-generated defaults/triggers.
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
fetchedRecord := reflect.New(modelType).Interface()
|
||||
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
|
||||
ScanModel(ctx); err == nil {
|
||||
v = mergeWithInput(fetchedRecord, v)
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, err)
|
||||
}
|
||||
} else {
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
}
|
||||
hookCtx.Result = v
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
@@ -444,7 +466,12 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
return v, nil
|
||||
|
||||
case []interface{}:
|
||||
results := make([]interface{}, 0, len(v))
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
originals := make([]map[string]interface{}, 0, len(v))
|
||||
insertedIDs := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
@@ -455,16 +482,43 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
for key, value := range itemMap {
|
||||
q = q.Value(key, value)
|
||||
}
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
if pkName == "" {
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, nil)
|
||||
continue
|
||||
}
|
||||
var returnedID interface{}
|
||||
if err := q.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, item)
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, returnedID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch create error: %w", err)
|
||||
}
|
||||
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||
results := make([]interface{}, 0, len(insertedIDs))
|
||||
for i, pkVal := range insertedIDs {
|
||||
if pkVal == nil {
|
||||
results = append(results, originals[i])
|
||||
continue
|
||||
}
|
||||
fetchedRecord := reflect.New(modelType).Interface()
|
||||
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||
ScanModel(ctx); err == nil {
|
||||
results = append(results, mergeWithInput(fetchedRecord, originals[i]))
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, err)
|
||||
results = append(results, originals[i])
|
||||
}
|
||||
}
|
||||
hookCtx.Result = results
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
@@ -584,6 +638,25 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Re-fetch the record after transaction commits to capture DB-generated changes.
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
fetchedRecord := reflect.New(modelType).Interface()
|
||||
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||
ScanModel(ctx); err == nil {
|
||||
jsonData, marshalErr := json.Marshal(fetchedRecord)
|
||||
if marshalErr == nil {
|
||||
var fetchedMap map[string]interface{}
|
||||
if json.Unmarshal(jsonData, &fetchedMap) == nil {
|
||||
updateResult = fetchedMap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return updateResult, nil
|
||||
}
|
||||
|
||||
@@ -749,6 +822,28 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition st
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// mergeWithInput merges a database record with the original request data.
|
||||
// DB values take precedence (capturing triggers/defaults), while extra
|
||||
// input keys that have no DB column are preserved in the response.
|
||||
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{}, len(input))
|
||||
for k, v := range input {
|
||||
result[k] = v
|
||||
}
|
||||
jsonData, err := json.Marshal(dbRecord)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
var dbMap map[string]interface{}
|
||||
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||
return result
|
||||
}
|
||||
for k, v := range dbMap {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for i := range preloads {
|
||||
preload := &preloads[i]
|
||||
|
||||
+187
-25
@@ -602,23 +602,44 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Standard processing without nested relations
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
return
|
||||
var responseData interface{} = v
|
||||
if pkName == "" {
|
||||
// No PK on model — insert and return input as-is.
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||
} else {
|
||||
var insertedID interface{}
|
||||
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created record with %s: %v", pkName, insertedID)
|
||||
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
|
||||
ScanModel(ctx); fetchErr == nil {
|
||||
responseData = mergeWithInput(fetchedRecord, v)
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, fetchErr)
|
||||
}
|
||||
}
|
||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, v, nil)
|
||||
h.sendResponse(w, responseData, nil)
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Check if any item needs nested processing
|
||||
@@ -666,15 +687,30 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Standard batch insert without nested relations
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
|
||||
originals := make([]map[string]interface{}, 0, len(v))
|
||||
insertedIDs := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
txQuery := tx.NewInsert().Table(tableName)
|
||||
for key, value := range item {
|
||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
if pkName == "" {
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, item)
|
||||
insertedIDs = append(insertedIDs, nil)
|
||||
continue
|
||||
}
|
||||
var returnedID interface{}
|
||||
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, item)
|
||||
insertedIDs = append(insertedIDs, returnedID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -689,7 +725,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, v, nil)
|
||||
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||
responseItems := make([]interface{}, 0, len(insertedIDs))
|
||||
for i, pkVal := range insertedIDs {
|
||||
if pkVal == nil {
|
||||
responseItems = append(responseItems, originals[i])
|
||||
continue
|
||||
}
|
||||
fetchedRecord := reflect.New(modelElemType).Interface()
|
||||
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||
ScanModel(ctx); fetchErr == nil {
|
||||
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
|
||||
responseItems = append(responseItems, originals[i])
|
||||
}
|
||||
}
|
||||
h.sendResponse(w, responseItems, nil)
|
||||
|
||||
case []interface{}:
|
||||
// Handle []interface{} type from JSON unmarshaling
|
||||
@@ -742,19 +795,34 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Standard batch insert without nested relations
|
||||
list := make([]interface{}, 0)
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
|
||||
originals := make([]map[string]interface{}, 0, len(v))
|
||||
insertedIDs := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
txQuery := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
txQuery := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
if pkName == "" {
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
list = append(list, item)
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, nil)
|
||||
continue
|
||||
}
|
||||
var returnedID interface{}
|
||||
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, returnedID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -769,7 +837,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, list, nil)
|
||||
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||
responseItems := make([]interface{}, 0, len(insertedIDs))
|
||||
for i, pkVal := range insertedIDs {
|
||||
if pkVal == nil {
|
||||
responseItems = append(responseItems, originals[i])
|
||||
continue
|
||||
}
|
||||
fetchedRecord := reflect.New(modelElemType).Interface()
|
||||
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||
ScanModel(ctx); fetchErr == nil {
|
||||
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
|
||||
responseItems = append(responseItems, originals[i])
|
||||
}
|
||||
}
|
||||
h.sendResponse(w, responseItems, nil)
|
||||
|
||||
default:
|
||||
logger.Error("Invalid data type for create operation: %T", data)
|
||||
@@ -836,7 +921,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...)
|
||||
|
||||
// Apply conditions to select
|
||||
if urlID != "" {
|
||||
@@ -955,13 +1040,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch the updated record after the transaction commits to capture any trigger changes
|
||||
updatedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
fetchQuery := h.db.NewSelect().Model(updatedRecord).Column(reflection.GetSQLModelColumns(model)...)
|
||||
if urlID != "" {
|
||||
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
if len(id) > 0 {
|
||||
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated record(s)")
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, data, nil)
|
||||
h.sendResponse(w, updatedRecord, nil)
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Batch update with array of objects
|
||||
@@ -1017,7 +1123,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
@@ -1089,13 +1195,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records", len(updates))
|
||||
|
||||
// Fetch updated records after the transaction commits to capture any trigger changes
|
||||
fetchedUpdates := make([]interface{}, 0, len(updates))
|
||||
for _, item := range updates {
|
||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
||||
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
fetchedUpdates = append(fetchedUpdates, fetchedRecord)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", len(fetchedUpdates))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, updates, nil)
|
||||
h.sendResponse(w, fetchedUpdates, nil)
|
||||
|
||||
case []interface{}:
|
||||
// Batch update with []interface{}
|
||||
@@ -1157,7 +1279,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
@@ -1232,13 +1354,31 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records", len(list))
|
||||
|
||||
// Fetch updated records after the transaction commits to capture any trigger changes
|
||||
fetchedList := make([]interface{}, 0, len(list))
|
||||
for _, item := range list {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if itemID, ok := itemMap["id"]; ok && itemID != nil {
|
||||
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
fetchedList = append(fetchedList, fetchedRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", len(fetchedList))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, list, nil)
|
||||
h.sendResponse(w, fetchedList, nil)
|
||||
|
||||
default:
|
||||
logger.Error("Invalid data type for update operation: %T", data)
|
||||
@@ -2067,3 +2207,25 @@ func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
|
||||
h.openAPIGenerator = generator
|
||||
}
|
||||
|
||||
// mergeWithInput merges a database record with the original request data.
|
||||
// DB values take precedence (capturing triggers/defaults), while extra
|
||||
// input keys that have no DB column are preserved in the response.
|
||||
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{}, len(input))
|
||||
for k, v := range input {
|
||||
result[k] = v
|
||||
}
|
||||
jsonData, err := json.Marshal(dbRecord)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
var dbMap map[string]interface{}
|
||||
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||
return result
|
||||
}
|
||||
for k, v := range dbMap {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -2011,11 +2011,15 @@ func (h *Handler) processChildRelationsForField(
|
||||
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
|
||||
var foreignKeyFieldName string
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Get the JSON name for the foreign key field in the child model
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||
// For has-many/has-one: join:parentCol=childCol
|
||||
// ForeignKey = parent side, References = child side (where we actually set the value)
|
||||
childField := relInfo.ForeignKey
|
||||
if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
|
||||
childField = relInfo.References
|
||||
}
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
|
||||
if foreignKeyFieldName == "" {
|
||||
// Fallback to lowercase field name
|
||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||
foreignKeyFieldName = strings.ToLower(childField)
|
||||
}
|
||||
} else {
|
||||
// Fallback: use parent's primary key name
|
||||
@@ -2039,7 +2043,10 @@ func (h *Handler) processChildRelationsForField(
|
||||
// Process based on relation type and data structure
|
||||
switch v := relationValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Single related object - add parent ID to foreign key field
|
||||
if !isValidNestedRequest(v) {
|
||||
logger.Debug("Skipping single relation %s - missing or invalid _request value", relationName)
|
||||
return nil
|
||||
}
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
v[foreignKeyFieldName] = parentID
|
||||
@@ -2056,7 +2063,10 @@ func (h *Handler) processChildRelationsForField(
|
||||
// Multiple related objects
|
||||
for i, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
// Add parent ID to foreign key field
|
||||
if !isValidNestedRequest(itemMap) {
|
||||
logger.Debug("Skipping relation array[%d] %s - missing or invalid _request value", i, relationName)
|
||||
continue
|
||||
}
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
@@ -2074,7 +2084,10 @@ func (h *Handler) processChildRelationsForField(
|
||||
case []map[string]interface{}:
|
||||
// Multiple related objects (typed slice)
|
||||
for i, itemMap := range v {
|
||||
// Add parent ID to foreign key field
|
||||
if !isValidNestedRequest(itemMap) {
|
||||
logger.Debug("Skipping relation typed array[%d] %s - missing or invalid _request value", i, relationName)
|
||||
continue
|
||||
}
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
@@ -2095,6 +2108,24 @@ func (h *Handler) processChildRelationsForField(
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidNestedRequest returns true only when the item carries a _request key
|
||||
// whose value is one of the recognised mutation verbs.
|
||||
func isValidNestedRequest(item map[string]interface{}) bool {
|
||||
raw, ok := item["_request"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
s, ok := raw.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "insert", "add", "change", "update", "delete", "remove":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getTableNameForRelatedModel gets the table name for a related model.
|
||||
// If the model's TableName() is schema-qualified (e.g. "public.users") the
|
||||
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
|
||||
|
||||
@@ -352,6 +352,45 @@ func (m *mockRegistry) GetAllModels() map[string]interface{} {
|
||||
return m.models
|
||||
}
|
||||
|
||||
// TestIsValidNestedRequest verifies that only the allowed _request verbs are accepted
|
||||
// and that items missing the key are rejected.
|
||||
func TestIsValidNestedRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
item map[string]interface{}
|
||||
expected bool
|
||||
}{
|
||||
// Valid verbs
|
||||
{name: "insert", item: map[string]interface{}{"_request": "insert"}, expected: true},
|
||||
{name: "add", item: map[string]interface{}{"_request": "add"}, expected: true},
|
||||
{name: "update", item: map[string]interface{}{"_request": "update"}, expected: true},
|
||||
{name: "change", item: map[string]interface{}{"_request": "change"}, expected: true},
|
||||
{name: "delete", item: map[string]interface{}{"_request": "delete"}, expected: true},
|
||||
{name: "remove", item: map[string]interface{}{"_request": "remove"}, expected: true},
|
||||
// Case-insensitive
|
||||
{name: "INSERT uppercase", item: map[string]interface{}{"_request": "INSERT"}, expected: true},
|
||||
{name: "Remove mixed case", item: map[string]interface{}{"_request": "Remove"}, expected: true},
|
||||
// Whitespace trimmed
|
||||
{name: "insert with spaces", item: map[string]interface{}{"_request": " insert "}, expected: true},
|
||||
// Invalid / missing
|
||||
{name: "missing _request", item: map[string]interface{}{"name": "foo"}, expected: false},
|
||||
{name: "empty string", item: map[string]interface{}{"_request": ""}, expected: false},
|
||||
{name: "unknown verb", item: map[string]interface{}{"_request": "create"}, expected: false},
|
||||
{name: "unknown verb modify", item: map[string]interface{}{"_request": "modify"}, expected: false},
|
||||
{name: "non-string value", item: map[string]interface{}{"_request": 42}, expected: false},
|
||||
{name: "nil value", item: map[string]interface{}{"_request": nil}, expected: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isValidNestedRequest(tt.item)
|
||||
if got != tt.expected {
|
||||
t.Errorf("isValidNestedRequest(%v) = %v, want %v", tt.item, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
||||
func TestMultiLevelRelationExtraction(t *testing.T) {
|
||||
registry := &mockRegistry{
|
||||
|
||||
@@ -671,6 +671,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
// Re-fetch the created record to capture DB-generated defaults/triggers.
|
||||
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
|
||||
hookCtx.ID = fmt.Sprintf("%v", pkVal)
|
||||
return h.readByID(hookCtx)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user