diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index 6e047fb..f6261c3 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -74,6 +74,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( } if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Error("Invalid model type: operation=%s, table=%s, modelType=%v, expected struct", operation, tableName, modelType) return nil, fmt.Errorf("model must be a struct type, got %v", modelType) } @@ -103,44 +104,64 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( // Get the primary key name for this model pkName := reflection.GetPrimaryKeyName(model) + // Check if we have any data to process (besides _request) + hasData := len(regularData) > 0 + // Process based on operation switch strings.ToLower(operation) { case "insert", "create": - id, err := p.processInsert(ctx, regularData, tableName) - if err != nil { - return nil, fmt.Errorf("insert failed: %w", err) - } - result.ID = id - result.AffectedRows = 1 - result.Data = regularData + // Only perform insert if we have data to insert + if hasData { + id, err := p.processInsert(ctx, regularData, tableName) + if err != nil { + logger.Error("Insert failed for table=%s, data=%+v, error=%v", tableName, regularData, err) + return nil, fmt.Errorf("insert failed: %w", err) + } + result.ID = id + result.AffectedRows = 1 + result.Data = regularData - // Process child relations after parent insert (to get parent ID) - if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil { - return nil, fmt.Errorf("failed to process child relations: %w", err) + // Process child relations after parent insert (to get parent ID) + if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil { + logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err) + return nil, fmt.Errorf("failed to process child relations: %w", err) + } + } else { + logger.Debug("Skipping insert for %s - no data columns besides _request", tableName) } case "update": - rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName]) - if err != nil { - return nil, fmt.Errorf("update failed: %w", err) - } - result.ID = data[pkName] - result.AffectedRows = rows - result.Data = regularData + // Only perform update if we have data to update + if hasData { + rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName]) + if err != nil { + logger.Error("Update failed for table=%s, id=%v, data=%+v, error=%v", tableName, data[pkName], regularData, err) + return nil, fmt.Errorf("update failed: %w", err) + } + result.ID = data[pkName] + result.AffectedRows = rows + result.Data = regularData - // Process child relations for update - if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil { - return nil, fmt.Errorf("failed to process child relations: %w", err) + // Process child relations for update + if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil { + logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err) + return nil, fmt.Errorf("failed to process child relations: %w", err) + } + } else { + logger.Debug("Skipping update for %s - no data columns besides _request", tableName) + result.ID = data[pkName] } case "delete": // Process child relations first (for referential integrity) - if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil { + if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil { + logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err) return nil, fmt.Errorf("failed to process child relations before delete: %w", err) } rows, err := p.processDelete(ctx, tableName, data[pkName]) if err != nil { + logger.Error("Delete failed for table=%s, id=%v, error=%v", tableName, data[pkName], err) return nil, fmt.Errorf("delete failed: %w", err) } result.ID = data[pkName] @@ -148,6 +169,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( result.Data = regularData default: + logger.Error("Unsupported operation: %s for table=%s", operation, tableName) return nil, fmt.Errorf("unsupported operation: %s", operation) } @@ -213,6 +235,7 @@ func (p *NestedCUDProcessor) processInsert( result, err := query.Exec(ctx) if err != nil { + logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err) return nil, fmt.Errorf("insert exec failed: %w", err) } @@ -236,6 +259,7 @@ func (p *NestedCUDProcessor) processUpdate( id interface{}, ) (int64, error) { if id == nil { + logger.Error("Update requires an ID: table=%s, data=%+v", tableName, data) return 0, fmt.Errorf("update requires an ID") } @@ -245,6 +269,7 @@ func (p *NestedCUDProcessor) processUpdate( result, err := query.Exec(ctx) if err != nil { + logger.Error("Update execution failed: table=%s, id=%v, data=%+v, error=%v", tableName, id, data, err) return 0, fmt.Errorf("update exec failed: %w", err) } @@ -256,6 +281,7 @@ func (p *NestedCUDProcessor) processUpdate( // processDelete handles delete operation func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) { if id == nil { + logger.Error("Delete requires an ID: table=%s", tableName) return 0, fmt.Errorf("delete requires an ID") } @@ -265,6 +291,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string result, err := query.Exec(ctx) if err != nil { + logger.Error("Delete execution failed: table=%s, id=%v, error=%v", tableName, id, err) return 0, fmt.Errorf("delete exec failed: %w", err) } @@ -281,6 +308,7 @@ func (p *NestedCUDProcessor) processChildRelations( relationFields map[string]*RelationshipInfo, relationData map[string]interface{}, parentModelType reflect.Type, + incomingParentIDs map[string]interface{}, // IDs from all ancestors ) error { for relationName, relInfo := range relationFields { relationValue, exists := relationData[relationName] @@ -293,7 +321,7 @@ func (p *NestedCUDProcessor) processChildRelations( // Get the related model field, found := parentModelType.FieldByName(relInfo.FieldName) if !found { - logger.Warn("Field %s not found in model", relInfo.FieldName) + logger.Error("Field %s not found in model type %v for relation %s", relInfo.FieldName, parentModelType, relationName) continue } @@ -313,20 +341,77 @@ func (p *NestedCUDProcessor) processChildRelations( relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName) // Prepare parent IDs for foreign key injection + // Start by copying all incoming parent IDs (from ancestors) parentIDs := make(map[string]interface{}) - if relInfo.ForeignKey != "" { + for k, v := range incomingParentIDs { + parentIDs[k] = v + } + logger.Debug("Inherited %d parent IDs from ancestors: %+v", len(incomingParentIDs), incomingParentIDs) + + // Add the current parent's primary key to the parentIDs map + // This ensures nested children have access to all ancestor IDs + if parentID != nil && parentModelType != nil { + // Get the parent model's primary key field name + parentPKFieldName := reflection.GetPrimaryKeyName(parentModelType) + if parentPKFieldName != "" { + // Get the JSON name for the primary key field + parentPKJSONName := reflection.GetJSONNameForField(parentModelType, parentPKFieldName) + baseName := "" + if len(parentPKJSONName) > 1 { + baseName = parentPKJSONName + } else { + // Add parent's PK to the map using the base model name + baseName = strings.TrimSuffix(parentPKFieldName, "ID") + baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id") + if baseName == "" { + baseName = "parent" + } + } + + parentIDs[baseName] = parentID + logger.Debug("Added current parent PK to parentIDs map: %s=%v (from field %s)", baseName, parentID, parentPKFieldName) + } + } + + // Also add the foreign key reference if specified + if relInfo.ForeignKey != "" && parentID != nil { // Extract the base name from foreign key (e.g., "DepartmentID" -> "Department") baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID") baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id") - parentIDs[baseName] = parentID + // Only add if different from what we already added + if _, exists := parentIDs[baseName]; !exists { + parentIDs[baseName] = parentID + logger.Debug("Added foreign key to parentIDs map: %s=%v (from FK %s)", baseName, parentID, relInfo.ForeignKey) + } + } + + logger.Debug("Final parentIDs map for relation %s: %+v", relationName, parentIDs) + + // Determine which field name to use for setting parent ID in child data + // Priority: Use foreign key field name if specified + var foreignKeyFieldName string + if relInfo.ForeignKey != "" { + // Get the JSON name for the foreign key field in the child model + foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) + if foreignKeyFieldName == "" { + // Fallback to lowercase field name + foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey) + } + logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey) } // Process based on relation type and data structure switch v := relationValue.(type) { case map[string]interface{}: - // Single related object + // Single related object - directly set foreign key if specified + if parentID != nil && foreignKeyFieldName != "" { + v[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID) + } _, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) if err != nil { + logger.Error("Failed to process single relation: name=%s, table=%s, operation=%s, parentID=%v, data=%+v, error=%v", + relationName, relatedTableName, operation, parentID, v, err) return fmt.Errorf("failed to process relation %s: %w", relationName, err) } @@ -334,24 +419,40 @@ func (p *NestedCUDProcessor) processChildRelations( // Multiple related objects for i, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { + // Directly set foreign key if specified + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { + logger.Error("Failed to process relation array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v", + relationName, i, relatedTableName, operation, parentID, itemMap, err) return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err) } + } else { + logger.Warn("Relation array item is not a map: name=%s[%d], type=%T", relationName, i, item) } } case []map[string]interface{}: // Multiple related objects (typed slice) for i, itemMap := range v { + // Directly set foreign key if specified + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { + logger.Error("Failed to process relation typed array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v", + relationName, i, relatedTableName, operation, parentID, itemMap, err) return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err) } } default: - logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue) + logger.Error("Unsupported relation data type: name=%s, type=%T, value=%+v", relationName, relationValue, relationValue) } } diff --git a/pkg/common/recursive_crud_test.go b/pkg/common/recursive_crud_test.go new file mode 100644 index 0000000..9bda8bb --- /dev/null +++ b/pkg/common/recursive_crud_test.go @@ -0,0 +1,720 @@ +package common + +import ( + "context" + "reflect" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +// Mock Database for testing +type mockDatabase struct { + insertCalls []map[string]interface{} + updateCalls []map[string]interface{} + deleteCalls []interface{} + lastID int64 +} + +func newMockDatabase() *mockDatabase { + return &mockDatabase{ + insertCalls: make([]map[string]interface{}, 0), + updateCalls: make([]map[string]interface{}, 0), + deleteCalls: make([]interface{}, 0), + lastID: 1, + } +} + +func (m *mockDatabase) NewSelect() SelectQuery { return &mockSelectQuery{} } +func (m *mockDatabase) NewInsert() InsertQuery { return &mockInsertQuery{db: m} } +func (m *mockDatabase) NewUpdate() UpdateQuery { return &mockUpdateQuery{db: m} } +func (m *mockDatabase) NewDelete() DeleteQuery { return &mockDeleteQuery{db: m} } +func (m *mockDatabase) RunInTransaction(ctx context.Context, fn func(Database) error) error { + return fn(m) +} +func (m *mockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) { + return &mockResult{rowsAffected: 1}, nil +} +func (m *mockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return nil +} +func (m *mockDatabase) BeginTx(ctx context.Context) (Database, error) { + return m, nil +} +func (m *mockDatabase) CommitTx(ctx context.Context) error { + return nil +} +func (m *mockDatabase) RollbackTx(ctx context.Context) error { + return nil +} +func (m *mockDatabase) GetUnderlyingDB() interface{} { + return nil +} + +// Mock SelectQuery +type mockSelectQuery struct{} + +func (m *mockSelectQuery) Model(model interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Table(name string) SelectQuery { return m } +func (m *mockSelectQuery) Column(columns ...string) SelectQuery { return m } +func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Where(condition string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Join(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m } +func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m } +func (m *mockSelectQuery) Order(order string) SelectQuery { return m } +func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Limit(n int) SelectQuery { return m } +func (m *mockSelectQuery) Offset(n int) SelectQuery { return m } +func (m *mockSelectQuery) Group(group string) SelectQuery { return m } +func (m *mockSelectQuery) Having(condition string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { return nil } +func (m *mockSelectQuery) ScanModel(ctx context.Context) error { return nil } +func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { return 0, nil } +func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { return false, nil } + +// Mock InsertQuery +type mockInsertQuery struct { + db *mockDatabase + table string + values map[string]interface{} +} + +func (m *mockInsertQuery) Model(model interface{}) InsertQuery { return m } +func (m *mockInsertQuery) Table(name string) InsertQuery { + m.table = name + return m +} +func (m *mockInsertQuery) Value(column string, value interface{}) InsertQuery { + if m.values == nil { + m.values = make(map[string]interface{}) + } + m.values[column] = value + return m +} +func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m } +func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m } +func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) { + // Record the insert call + m.db.insertCalls = append(m.db.insertCalls, m.values) + m.db.lastID++ + return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil +} + +// Mock UpdateQuery +type mockUpdateQuery struct { + db *mockDatabase + table string + setValues map[string]interface{} +} + +func (m *mockUpdateQuery) Model(model interface{}) UpdateQuery { return m } +func (m *mockUpdateQuery) Table(name string) UpdateQuery { + m.table = name + return m +} +func (m *mockUpdateQuery) Set(column string, value interface{}) UpdateQuery { return m } +func (m *mockUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery { + m.setValues = values + return m +} +func (m *mockUpdateQuery) Where(condition string, args ...interface{}) UpdateQuery { return m } +func (m *mockUpdateQuery) Returning(columns ...string) UpdateQuery { return m } +func (m *mockUpdateQuery) Exec(ctx context.Context) (Result, error) { + // Record the update call + m.db.updateCalls = append(m.db.updateCalls, m.setValues) + return &mockResult{rowsAffected: 1}, nil +} + +// Mock DeleteQuery +type mockDeleteQuery struct { + db *mockDatabase + table string +} + +func (m *mockDeleteQuery) Model(model interface{}) DeleteQuery { return m } +func (m *mockDeleteQuery) Table(name string) DeleteQuery { + m.table = name + return m +} +func (m *mockDeleteQuery) Where(condition string, args ...interface{}) DeleteQuery { return m } +func (m *mockDeleteQuery) Exec(ctx context.Context) (Result, error) { + // Record the delete call + m.db.deleteCalls = append(m.db.deleteCalls, m.table) + return &mockResult{rowsAffected: 1}, nil +} + +// Mock Result +type mockResult struct { + lastID int64 + rowsAffected int64 +} + +func (m *mockResult) LastInsertId() (int64, error) { return m.lastID, nil } +func (m *mockResult) RowsAffected() int64 { return m.rowsAffected } + +// Mock ModelRegistry +type mockModelRegistry struct{} + +func (m *mockModelRegistry) GetModel(name string) (interface{}, error) { return nil, nil } +func (m *mockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { return nil, nil } +func (m *mockModelRegistry) RegisterModel(name string, model interface{}) error { return nil } +func (m *mockModelRegistry) GetAllModels() map[string]interface{} { return make(map[string]interface{}) } + +// Mock RelationshipInfoProvider +type mockRelationshipProvider struct { + relationships map[string]*RelationshipInfo +} + +func newMockRelationshipProvider() *mockRelationshipProvider { + return &mockRelationshipProvider{ + relationships: make(map[string]*RelationshipInfo), + } +} + +func (m *mockRelationshipProvider) GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo { + key := modelType.Name() + "." + relationName + return m.relationships[key] +} + +func (m *mockRelationshipProvider) RegisterRelation(modelTypeName, relationName string, info *RelationshipInfo) { + key := modelTypeName + "." + relationName + m.relationships[key] = info +} + +// Test Models +type Department struct { + ID int64 `json:"id" bun:"id,pk"` + Name string `json:"name"` + Employees []*Employee `json:"employees,omitempty"` +} + +func (d Department) TableName() string { return "departments" } +func (d Department) GetIDName() string { return "ID" } + +type Employee struct { + ID int64 `json:"id" bun:"id,pk"` + Name string `json:"name"` + DepartmentID int64 `json:"department_id"` + Tasks []*Task `json:"tasks,omitempty"` +} + +func (e Employee) TableName() string { return "employees" } +func (e Employee) GetIDName() string { return "ID" } + +type Task struct { + ID int64 `json:"id" bun:"id,pk"` + Title string `json:"title"` + EmployeeID int64 `json:"employee_id"` + Comments []*Comment `json:"comments,omitempty"` +} + +func (t Task) TableName() string { return "tasks" } +func (t Task) GetIDName() string { return "ID" } + +type Comment struct { + ID int64 `json:"id" bun:"id,pk"` + Text string `json:"text"` + TaskID int64 `json:"task_id"` +} + +func (c Comment) TableName() string { return "comments" } +func (c Comment) GetIDName() string { return "ID" } + +// Test Cases + +func TestProcessNestedCUD_SingleLevelInsert(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + // Register Department -> Employees relationship + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "name": "John Doe", + }, + map[string]interface{}{ + "name": "Jane Smith", + }, + }, + } + + result, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + if result.ID == nil { + t.Error("Expected result.ID to be set") + } + + // Verify department was inserted + if len(db.insertCalls) != 3 { + t.Errorf("Expected 3 insert calls (1 dept + 2 employees), got %d", len(db.insertCalls)) + } + + // Verify first insert is department + if db.insertCalls[0]["name"] != "Engineering" { + t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"]) + } + + // Verify employees were inserted with foreign key + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id set") + } + if db.insertCalls[2]["department_id"] == nil { + t.Error("Expected employee to have department_id set") + } +} + +func TestProcessNestedCUD_MultiLevelInsert(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + // Register relationships + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{ + FieldName: "Tasks", + JSONName: "tasks", + RelationType: "has_many", + ForeignKey: "EmployeeID", + RelatedModel: Task{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "name": "John Doe", + "tasks": []interface{}{ + map[string]interface{}{ + "title": "Task 1", + }, + map[string]interface{}{ + "title": "Task 2", + }, + }, + }, + }, + } + + result, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + if result.ID == nil { + t.Error("Expected result.ID to be set") + } + + // Verify: 1 dept + 1 employee + 2 tasks = 4 inserts + if len(db.insertCalls) != 4 { + t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls)) + } + + // Verify department + if db.insertCalls[0]["name"] != "Engineering" { + t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"]) + } + + // Verify employee has department_id + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id set") + } + + // Verify tasks have employee_id + if db.insertCalls[2]["employee_id"] == nil { + t.Error("Expected task to have employee_id set") + } + if db.insertCalls[3]["employee_id"] == nil { + t.Error("Expected task to have employee_id set") + } +} + +func TestProcessNestedCUD_RequestFieldOverride(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "_request": "update", + "ID": int64(10), // Use capital ID to match struct field + "name": "John Updated", + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Verify department was inserted (1 insert) + // Employee should be updated (1 update) + if len(db.insertCalls) != 1 { + t.Errorf("Expected 1 insert call for department, got %d", len(db.insertCalls)) + } + + if len(db.updateCalls) != 1 { + t.Errorf("Expected 1 update call for employee, got %d", len(db.updateCalls)) + } + + // Verify update data + if db.updateCalls[0]["name"] != "John Updated" { + t.Errorf("Expected employee name 'John Updated', got %v", db.updateCalls[0]["name"]) + } +} + +func TestProcessNestedCUD_SkipInsertWhenOnlyRequestField(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + // Data with only _request field for nested employee + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "_request": "insert", + // No other fields besides _request + // Note: Foreign key will be injected, so employee WILL be inserted + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Department + Employee (with injected FK) = 2 inserts + if len(db.insertCalls) != 2 { + t.Errorf("Expected 2 insert calls (department + employee with FK), got %d", len(db.insertCalls)) + } + + if db.insertCalls[0]["name"] != "Engineering" { + t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"]) + } + + // Verify employee has foreign key + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id injected") + } +} + +func TestProcessNestedCUD_Update(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "ID": int64(1), // Use capital ID to match struct field + "name": "Engineering Updated", + "employees": []interface{}{ + map[string]interface{}{ + "_request": "insert", + "name": "New Employee", + }, + }, + } + + result, err := processor.ProcessNestedCUD( + context.Background(), + "update", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + if result.ID != int64(1) { + t.Errorf("Expected result.ID to be 1, got %v", result.ID) + } + + // Verify department was updated + if len(db.updateCalls) != 1 { + t.Errorf("Expected 1 update call, got %d", len(db.updateCalls)) + } + + // Verify new employee was inserted + if len(db.insertCalls) != 1 { + t.Errorf("Expected 1 insert call for new employee, got %d", len(db.insertCalls)) + } +} + +func TestProcessNestedCUD_Delete(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "ID": int64(1), // Use capital ID to match struct field + "employees": []interface{}{ + map[string]interface{}{ + "_request": "delete", + "ID": int64(10), // Use capital ID + }, + map[string]interface{}{ + "_request": "delete", + "ID": int64(11), // Use capital ID + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "delete", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Verify employees were deleted first, then department + // 2 employees + 1 department = 3 deletes + if len(db.deleteCalls) != 3 { + t.Errorf("Expected 3 delete calls, got %d", len(db.deleteCalls)) + } +} + +func TestProcessNestedCUD_ParentIDPropagation(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + // Register 3-level relationships + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{ + FieldName: "Tasks", + JSONName: "tasks", + RelationType: "has_many", + ForeignKey: "EmployeeID", + RelatedModel: Task{}, + }) + + relProvider.RegisterRelation("Task", "comments", &RelationshipInfo{ + FieldName: "Comments", + JSONName: "comments", + RelationType: "has_many", + ForeignKey: "TaskID", + RelatedModel: Comment{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "name": "John", + "tasks": []interface{}{ + map[string]interface{}{ + "title": "Task 1", + "comments": []interface{}{ + map[string]interface{}{ + "text": "Great work!", + }, + }, + }, + }, + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Verify: 1 dept + 1 employee + 1 task + 1 comment = 4 inserts + if len(db.insertCalls) != 4 { + t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls)) + } + + // Verify department + if db.insertCalls[0]["name"] != "Engineering" { + t.Error("Expected department to be inserted first") + } + + // Verify employee has department_id + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id") + } + + // Verify task has employee_id + if db.insertCalls[2]["employee_id"] == nil { + t.Error("Expected task to have employee_id") + } + + // Verify comment has task_id + if db.insertCalls[3]["task_id"] == nil { + t.Error("Expected comment to have task_id") + } +} + +func TestInjectForeignKeys(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "John", + } + + parentIDs := map[string]interface{}{ + "department": int64(5), + } + + modelType := reflect.TypeOf(Employee{}) + + processor.injectForeignKeys(data, modelType, parentIDs) + + // Should inject department_id based on the "department" key in parentIDs + if data["department_id"] == nil { + t.Error("Expected department_id to be injected") + } + + if data["department_id"] != int64(5) { + t.Errorf("Expected department_id to be 5, got %v", data["department_id"]) + } +} + +func TestGetPrimaryKeyName(t *testing.T) { + dept := Department{} + pkName := reflection.GetPrimaryKeyName(dept) + + if pkName != "ID" { + t.Errorf("Expected primary key name 'ID', got '%s'", pkName) + } + + // Test with pointer + pkName2 := reflection.GetPrimaryKeyName(&dept) + if pkName2 != "ID" { + t.Errorf("Expected primary key name 'ID' from pointer, got '%s'", pkName2) + } +} diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index 155f30c..2ae1a88 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -1,6 +1,9 @@ package reflection -import "reflect" +import ( + "reflect" + "strings" +) func Len(v any) int { val := reflect.ValueOf(v) @@ -64,3 +67,41 @@ func GetPointerElement(v reflect.Type) reflect.Type { } return v } + +// GetJSONNameForField gets the JSON tag name for a struct field. +// Returns the JSON field name from the json struct tag, or an empty string if not found. +// Handles the "json" tag format: "name", "name,omitempty", etc. +func GetJSONNameForField(modelType reflect.Type, fieldName string) string { + if modelType == nil { + return "" + } + + // Handle pointer types + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + return "" + } + + // Find the field + field, found := modelType.FieldByName(fieldName) + if !found { + return "" + } + + // Get the JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + return "" + } + + // Parse the tag (format: "name,omitempty" or just "name") + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" && parts[0] != "-" { + return parts[0] + } + + return "" +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5118073..c4c2a3d 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1794,10 +1794,36 @@ func (h *Handler) processChildRelationsForField( parentIDs[baseName] = parentID } + // Determine which field name to use for setting parent ID in child data + // Priority: Use foreign key field name if specified, otherwise use parent's PK name + var foreignKeyFieldName string + if relInfo.ForeignKey != "" { + // Get the JSON name for the foreign key field in the child model + foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) + if foreignKeyFieldName == "" { + // Fallback to lowercase field name + foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey) + } + } else { + // Fallback: use parent's primary key name + parentPKName := reflection.GetPrimaryKeyName(parentModelType) + foreignKeyFieldName = reflection.GetJSONNameForField(parentModelType, parentPKName) + if foreignKeyFieldName == "" { + foreignKeyFieldName = strings.ToLower(parentPKName) + } + } + + logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s", + foreignKeyFieldName, parentID, relInfo.ForeignKey) + // Process based on relation type and data structure switch v := relationValue.(type) { case map[string]interface{}: - // Single related object + // Single related object - add parent ID to foreign key field + if parentID != nil && foreignKeyFieldName != "" { + v[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID) + } _, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) if err != nil { return fmt.Errorf("failed to process single relation: %w", err) @@ -1807,6 +1833,11 @@ func (h *Handler) processChildRelationsForField( // Multiple related objects for i, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { + // Add parent ID to foreign key field + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { return fmt.Errorf("failed to process relation item %d: %w", i, err) @@ -1817,6 +1848,11 @@ func (h *Handler) processChildRelationsForField( case []map[string]interface{}: // Multiple related objects (typed slice) for i, itemMap := range v { + // Add parent ID to foreign key field + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { return fmt.Errorf("failed to process relation item %d: %w", i, err)