From 14daea3b056928edc08330ecb369232ffbb4ad05 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 19 Nov 2025 15:08:04 +0200 Subject: [PATCH] Fixes for CUD operations --- pkg/reflection/model_utils.go | 56 +++ pkg/restheadspec/handler.go | 587 ++++++++++++++---------- pkg/restheadspec/handler_nested_test.go | 393 ++++++++++++++++ 3 files changed, 799 insertions(+), 237 deletions(-) create mode 100644 pkg/restheadspec/handler_nested_test.go diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 0ca6a94..2363ca4 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -45,6 +45,62 @@ func GetPrimaryKeyName(model any) string { return "" } +// GetPrimaryKeyValue extracts the primary key value from a model instance +// Returns the value of the primary key field +func GetPrimaryKeyValue(model any) interface{} { + if model == nil || reflect.TypeOf(model) == nil { + return nil + } + + val := reflect.ValueOf(model) + if val.Kind() == reflect.Pointer { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil + } + + typ := val.Type() + + // Try Bun tag first + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + bunTag := field.Tag.Get("bun") + if strings.Contains(bunTag, "pk") { + fieldValue := val.Field(i) + if fieldValue.CanInterface() { + return fieldValue.Interface() + } + } + } + + // Fall back to GORM tag + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + gormTag := field.Tag.Get("gorm") + if strings.Contains(gormTag, "primaryKey") { + fieldValue := val.Field(i) + if fieldValue.CanInterface() { + return fieldValue.Interface() + } + } + } + + // Last resort: look for field named "ID" or "Id" + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + if strings.ToLower(field.Name) == "id" { + fieldValue := val.Field(i) + if fieldValue.CanInterface() { + return fieldValue.Interface() + } + } + } + + return nil +} + // GetModelColumns extracts all column names from a model using reflection // It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names func GetModelColumns(model any) []string { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index a6d8fe4..32c6e90 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -584,22 +584,6 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat logger.Info("Creating record in %s.%s", schema, entity) - // Check if data is a single map with nested relations - if dataMap, ok := data.(map[string]interface{}); ok { - if h.shouldUseNestedProcessor(dataMap, model) { - logger.Info("Using nested CUD processor for create operation") - result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", dataMap, model, make(map[string]interface{}), tableName) - if err != nil { - logger.Error("Error in nested create: %v", err) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record with nested data", err) - return - } - logger.Info("Successfully created record with nested data, ID: %v", result.ID) - h.sendResponseWithOptions(w, result.Data, nil, &options) - return - } - } - // Execute BeforeCreate hooks hookCtx := &HookContext{ Context: ctx, @@ -622,172 +606,113 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat // Use potentially modified data from hook context data = hookCtx.Data - // Handle batch creation - dataValue := reflect.ValueOf(data) - if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array { - logger.Debug("Batch creation detected, count: %d", dataValue.Len()) + // Normalize data to slice for unified processing + dataSlice := h.normalizeToSlice(data) + logger.Debug("Processing %d item(s) for creation", len(dataSlice)) - // Check if any item needs nested processing - hasNestedData := false - for i := 0; i < dataValue.Len(); i++ { - item := dataValue.Index(i).Interface() - if itemMap, ok := item.(map[string]interface{}); ok { - if h.shouldUseNestedProcessor(itemMap, model) { - hasNestedData = true - break - } - } - } + // Process all items in a transaction + results := make([]interface{}, 0, len(dataSlice)) + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Create temporary nested processor with transaction + txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) - if hasNestedData { - logger.Info("Using nested CUD processor for batch create with nested data") - results := make([]interface{}, 0, dataValue.Len()) - err := h.db.RunInTransaction(ctx, func(tx common.Database) error { - // Temporarily swap the database to use transaction - originalDB := h.nestedProcessor - h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h) - defer func() { - h.nestedProcessor = originalDB - }() - - for i := 0; i < dataValue.Len(); i++ { - item := dataValue.Index(i).Interface() - if itemMap, ok := item.(map[string]interface{}); ok { - result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", itemMap, model, make(map[string]interface{}), tableName) - if err != nil { - return fmt.Errorf("failed to process item: %w", err) - } - results = append(results, result.Data) - } - } - return nil - }) - if err != nil { - logger.Error("Error creating records with nested data: %v", err) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err) - return - } - - // Execute AfterCreate hooks - hookCtx.Result = map[string]interface{}{"created": len(results), "data": results} - hookCtx.Error = nil - - if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { - logger.Error("AfterCreate hook failed: %v", err) - h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) - return - } - - logger.Info("Successfully created %d records with nested data", len(results)) - h.sendResponseWithOptions(w, results, nil, &options) - return - } - - // Standard batch insert without nested relations - // Use transaction for batch insert - err := h.db.RunInTransaction(ctx, func(tx common.Database) error { - for i := 0; i < dataValue.Len(); i++ { - item := dataValue.Index(i).Interface() - - // Convert item to model type - create a pointer to the model - modelValue := reflect.New(reflect.TypeOf(model)).Interface() + for i, item := range dataSlice { + itemMap, ok := item.(map[string]interface{}) + if !ok { + // Convert to map if needed jsonData, err := json.Marshal(item) if err != nil { - return fmt.Errorf("failed to marshal item: %w", err) + return fmt.Errorf("failed to marshal item %d: %w", i, err) } - if err := json.Unmarshal(jsonData, modelValue); err != nil { - return fmt.Errorf("failed to unmarshal item: %w", err) - } - - query := tx.NewInsert().Model(modelValue).Table(tableName) - - // Execute BeforeScan hooks - pass query chain so hooks can modify it - batchHookCtx := &HookContext{ - Context: ctx, - Handler: h, - Schema: schema, - Entity: entity, - TableName: tableName, - Model: model, - Options: options, - Data: modelValue, - Writer: w, - Query: query, - } - if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil { - return fmt.Errorf("BeforeScan hook failed: %w", err) - } - - // Use potentially modified query from hook context - if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok { - query = modifiedQuery - } - - if _, err := query.Exec(ctx); err != nil { - return fmt.Errorf("failed to insert record: %w", err) + itemMap = make(map[string]interface{}) + if err := json.Unmarshal(jsonData, &itemMap); err != nil { + return fmt.Errorf("failed to unmarshal item %d: %w", i, err) } } - return nil - }) - if err != nil { - logger.Error("Error creating records: %v", err) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err) - return + // Extract nested relations if present (but don't process them yet) + var nestedRelations map[string]interface{} + if h.shouldUseNestedProcessor(itemMap, model) { + logger.Debug("Extracting nested relations for item %d", i) + cleanedData, relations, err := h.extractNestedRelations(itemMap, model) + if err != nil { + return fmt.Errorf("failed to extract nested relations for item %d: %w", i, err) + } + itemMap = cleanedData + nestedRelations = relations + } + + // Convert item to model type - create a pointer to the model + modelValue := reflect.New(reflect.TypeOf(model)).Interface() + jsonData, err := json.Marshal(itemMap) + if err != nil { + return fmt.Errorf("failed to marshal item %d: %w", i, err) + } + if err := json.Unmarshal(jsonData, modelValue); err != nil { + return fmt.Errorf("failed to unmarshal item %d: %w", i, err) + } + + // Create insert query + query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*") + + // Execute BeforeScan hooks - pass query chain so hooks can modify it + itemHookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Model: model, + Options: options, + Data: modelValue, + Writer: w, + Query: query, + } + if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil { + return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err) + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := itemHookCtx.Query.(common.InsertQuery); ok { + query = modifiedQuery + } + + // Execute insert and get the ID + if _, err := query.Exec(ctx); err != nil { + return fmt.Errorf("failed to insert item %d: %w", i, err) + } + + // Get the inserted ID + insertedID := reflection.GetPrimaryKeyValue(modelValue) + + // Now process nested relations with the parent ID + if len(nestedRelations) > 0 { + logger.Debug("Processing nested relations for item %d with parent ID: %v", i, insertedID) + if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "insert", nestedRelations, model, insertedID); err != nil { + return fmt.Errorf("failed to process nested relations for item %d: %w", i, err) + } + } + + results = append(results, modelValue) } + return nil + }) - // Execute AfterCreate hooks for batch creation - hookCtx.Result = map[string]interface{}{"created": dataValue.Len()} - hookCtx.Error = nil - - if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { - logger.Error("AfterCreate hook failed: %v", err) - h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) - return - } - - h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil) - return - } - - // Single record creation - create a pointer to the model - modelValue := reflect.New(reflect.TypeOf(model)).Interface() - jsonData, err := json.Marshal(data) if err != nil { - logger.Error("Error marshaling data: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) - return - } - if err := json.Unmarshal(jsonData, modelValue); err != nil { - logger.Error("Error unmarshaling data: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) + logger.Error("Error creating records: %v", err) + h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err) return } - query := h.db.NewInsert().Model(modelValue).Table(tableName) - - // Execute BeforeScan hooks - pass query chain so hooks can modify it - hookCtx.Data = modelValue - hookCtx.Query = query - if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { - logger.Error("BeforeScan hook failed: %v", err) - h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return + // Execute AfterCreate hooks + var responseData interface{} + if len(results) == 1 { + responseData = results[0] + hookCtx.Result = results[0] + } else { + responseData = results + hookCtx.Result = map[string]interface{}{"created": len(results), "data": results} } - - // Use potentially modified query from hook context - if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok { - query = modifiedQuery - } - - if _, err := query.Exec(ctx); err != nil { - logger.Error("Error creating record: %v", err) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) - return - } - - // Execute AfterCreate hooks for single record creation - hookCtx.Result = modelValue hookCtx.Error = nil if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { @@ -796,7 +721,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat return } - h.sendResponseWithOptions(w, modelValue, nil, &options) + logger.Info("Successfully created %d record(s)", len(results)) + h.sendResponseWithOptions(w, responseData, nil, &options) } func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) { @@ -814,46 +740,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id logger.Info("Updating record in %s.%s", schema, entity) - // Convert data to map first for nested processor check - dataMap, ok := data.(map[string]interface{}) - if !ok { - jsonData, err := json.Marshal(data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) - return - } - if err := json.Unmarshal(jsonData, &dataMap); err != nil { - logger.Error("Error unmarshaling data: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) - return - } - } - - // Check if we should use nested processing - if h.shouldUseNestedProcessor(dataMap, model) { - logger.Info("Using nested CUD processor for update operation") - // Ensure ID is in the data map - var targetID interface{} - if id != "" { - targetID = id - } else if idPtr != nil { - targetID = *idPtr - } - if targetID != nil { - dataMap["id"] = targetID - } - result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", dataMap, model, make(map[string]interface{}), tableName) - if err != nil { - logger.Error("Error in nested update: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record with nested data", err) - return - } - logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows) - h.sendResponseWithOptions(w, result.Data, nil, &options) - return - } - // Execute BeforeUpdate hooks hookCtx := &HookContext{ Context: ctx, @@ -877,8 +763,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Use potentially modified data from hook context data = hookCtx.Data - // Convert data to map (again if modified by hooks) - dataMap, ok = data.(map[string]interface{}) + // Convert data to map + dataMap, ok := data.(map[string]interface{}) if !ok { jsonData, err := json.Marshal(data) if err != nil { @@ -893,33 +779,74 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id } } - query := h.db.NewUpdate().Table(tableName).SetMap(dataMap) - pkName := reflection.GetPrimaryKeyName(model) - // Apply ID filter - switch { - case id != "": - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) - case idPtr != nil: - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *idPtr) - default: + // Determine target ID + var targetID interface{} + if id != "" { + targetID = id + } else if idPtr != nil { + targetID = *idPtr + } else { h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil) return } - // Execute BeforeScan hooks - pass query chain so hooks can modify it - hookCtx.Query = query - if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { - logger.Error("BeforeScan hook failed: %v", err) - h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } + // Process nested relations if present + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Create temporary nested processor with transaction + txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) - // Use potentially modified query from hook context - if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok { - query = modifiedQuery - } + // Extract nested relations if present (but don't process them yet) + var nestedRelations map[string]interface{} + if h.shouldUseNestedProcessor(dataMap, model) { + logger.Debug("Extracting nested relations for update") + cleanedData, relations, err := h.extractNestedRelations(dataMap, model) + if err != nil { + return fmt.Errorf("failed to extract nested relations: %w", err) + } + dataMap = cleanedData + nestedRelations = relations + } + + // Ensure ID is in the data map for the update + dataMap["id"] = targetID + + // Create update query + query := tx.NewUpdate().Table(tableName).SetMap(dataMap) + pkName := reflection.GetPrimaryKeyName(model) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + + // Execute BeforeScan hooks - pass query chain so hooks can modify it + hookCtx.Query = query + if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { + return fmt.Errorf("BeforeScan hook failed: %w", err) + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok { + query = modifiedQuery + } + + // Execute update + result, err := query.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to update record: %w", err) + } + + // Now process nested relations with the parent ID + if len(nestedRelations) > 0 { + logger.Debug("Processing nested relations for update with parent ID: %v", targetID) + if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "update", nestedRelations, model, targetID); err != nil { + return fmt.Errorf("failed to process nested relations: %w", err) + } + } + + // Store result for hooks + hookCtx.Result = map[string]interface{}{ + "updated": result.RowsAffected(), + } + return nil + }) - result, err := query.Exec(ctx) if err != nil { logger.Error("Error updating record: %v", err) h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err) @@ -927,19 +854,15 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id } // Execute AfterUpdate hooks - responseData := map[string]interface{}{ - "updated": result.RowsAffected(), - } - hookCtx.Result = responseData hookCtx.Error = nil - if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { logger.Error("AfterUpdate hook failed: %v", err) h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) return } - h.sendResponseWithOptions(w, responseData, nil, &options) + logger.Info("Successfully updated record with ID: %v", targetID) + h.sendResponseWithOptions(w, hookCtx.Result, nil, &options) } func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) { @@ -1199,6 +1122,196 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id h.sendResponse(w, responseData, nil) } +// normalizeToSlice converts data to a slice. Single items become a 1-item slice. +func (h *Handler) normalizeToSlice(data interface{}) []interface{} { + if data == nil { + return []interface{}{} + } + + dataValue := reflect.ValueOf(data) + if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array { + result := make([]interface{}, dataValue.Len()) + for i := 0; i < dataValue.Len(); i++ { + result[i] = dataValue.Index(i).Interface() + } + return result + } + + // Single item - return as 1-item slice + return []interface{}{data} +} + +// extractNestedRelations extracts nested relations from data, returning cleaned data and relations +// This does NOT process the relations, just separates them for later processing +func (h *Handler) extractNestedRelations( + data map[string]interface{}, + model interface{}, +) (map[string]interface{}, map[string]interface{}, error) { + // Get model type for reflection + modelType := reflect.TypeOf(model) + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return data, nil, fmt.Errorf("model must be a struct type, got %v", modelType) + } + + // Separate relation fields from regular fields + cleanedData := make(map[string]interface{}) + relations := make(map[string]interface{}) + + for key, value := range data { + // Skip _request field + if key == "_request" { + continue + } + + // Check if this field is a relation + relInfo := h.GetRelationshipInfo(modelType, key) + if relInfo != nil { + logger.Debug("Found nested relation field: %s (type: %s)", key, relInfo.RelationType) + relations[key] = value + } else { + cleanedData[key] = value + } + } + + return cleanedData, relations, nil +} + +// processChildRelationsWithParentID processes nested relations with a parent ID +func (h *Handler) processChildRelationsWithParentID( + ctx context.Context, + processor *common.NestedCUDProcessor, + operation string, + relations map[string]interface{}, + parentModel interface{}, + parentID interface{}, +) error { + // Get model type for reflection + modelType := reflect.TypeOf(parentModel) + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return fmt.Errorf("model must be a struct type, got %v", modelType) + } + + // Process each relation + for relationName, relationValue := range relations { + if relationValue == nil { + continue + } + + // Get relationship info + relInfo := h.GetRelationshipInfo(modelType, relationName) + if relInfo == nil { + logger.Warn("No relationship info found for %s, skipping", relationName) + continue + } + + // Process this relation with parent ID + if err := h.processChildRelationsForField(ctx, processor, operation, relationName, relationValue, relInfo, modelType, parentID); err != nil { + return fmt.Errorf("failed to process relation %s: %w", relationName, err) + } + } + + return nil +} + +// processChildRelationsForField processes a single nested relation field +func (h *Handler) processChildRelationsForField( + ctx context.Context, + processor *common.NestedCUDProcessor, + operation string, + relationName string, + relationValue interface{}, + relInfo *common.RelationshipInfo, + parentModelType reflect.Type, + parentID interface{}, +) error { + if relationValue == nil { + return nil + } + + // Get the related model + field, found := parentModelType.FieldByName(relInfo.FieldName) + if !found { + return fmt.Errorf("field %s not found in model", relInfo.FieldName) + } + + // Get the model type for the relation + relatedModelType := field.Type + if relatedModelType.Kind() == reflect.Slice { + relatedModelType = relatedModelType.Elem() + } + if relatedModelType.Kind() == reflect.Ptr { + relatedModelType = relatedModelType.Elem() + } + + // Create an instance of the related model + relatedModel := reflect.New(relatedModelType).Elem().Interface() + + // Get table name for related model + relatedTableName := h.getTableNameForRelatedModel(relatedModel, relInfo.JSONName) + + // Prepare parent IDs for foreign key injection + parentIDs := make(map[string]interface{}) + if relInfo.ForeignKey != "" && parentID != nil { + baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID") + baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id") + parentIDs[baseName] = parentID + } + + // Process based on relation type and data structure + switch v := relationValue.(type) { + case map[string]interface{}: + // Single related object + _, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) + if err != nil { + return fmt.Errorf("failed to process single relation: %w", err) + } + + case []interface{}: + // Multiple related objects + for i, item := range v { + if itemMap, ok := item.(map[string]interface{}); ok { + _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) + if err != nil { + return fmt.Errorf("failed to process relation item %d: %w", i, err) + } + } + } + + case []map[string]interface{}: + // Multiple related objects (typed slice) + for i, itemMap := range v { + _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) + if err != nil { + return fmt.Errorf("failed to process relation item %d: %w", i, err) + } + } + + default: + return fmt.Errorf("unsupported relation data type: %T", relationValue) + } + + return nil +} + +// getTableNameForRelatedModel gets the table name for a related model +func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string { + if provider, ok := model.(common.TableNameProvider); ok { + tableName := provider.TableName() + if tableName != "" { + return tableName + } + } + return defaultName +} + // qualifyColumnName ensures column name is fully qualified with table name if not already func (h *Handler) qualifyColumnName(columnName, fullTableName string) string { // Check if column already has a table/schema prefix (contains a dot) diff --git a/pkg/restheadspec/handler_nested_test.go b/pkg/restheadspec/handler_nested_test.go new file mode 100644 index 0000000..943ee3a --- /dev/null +++ b/pkg/restheadspec/handler_nested_test.go @@ -0,0 +1,393 @@ +package restheadspec + +import ( + "fmt" + "reflect" + "testing" +) + +// Test models for nested CRUD operations +type TestUser struct { + ID int64 `json:"id" bun:"id,pk,autoincrement"` + Name string `json:"name"` + Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"` +} + +type TestPost struct { + ID int64 `json:"id" bun:"id,pk,autoincrement"` + UserID int64 `json:"user_id"` + Title string `json:"title"` + Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"` +} + +type TestComment struct { + ID int64 `json:"id" bun:"id,pk,autoincrement"` + PostID int64 `json:"post_id"` + Content string `json:"content"` +} + +func (TestUser) TableName() string { return "users" } +func (TestPost) TableName() string { return "posts" } +func (TestComment) TableName() string { return "comments" } + +// Test extractNestedRelations function +func TestExtractNestedRelations(t *testing.T) { + // Create handler + registry := &mockRegistry{ + models: map[string]interface{}{ + "users": TestUser{}, + "posts": TestPost{}, + "comments": TestComment{}, + }, + } + handler := NewHandler(nil, registry) + + tests := []struct { + name string + data map[string]interface{} + model interface{} + expectedCleanCount int + expectedRelCount int + }{ + { + name: "User with posts", + data: map[string]interface{}{ + "name": "John Doe", + "posts": []map[string]interface{}{ + {"title": "Post 1"}, + }, + }, + model: TestUser{}, + expectedCleanCount: 1, // name + expectedRelCount: 1, // posts + }, + { + name: "Post with comments", + data: map[string]interface{}{ + "title": "Test Post", + "comments": []map[string]interface{}{ + {"content": "Comment 1"}, + {"content": "Comment 2"}, + }, + }, + model: TestPost{}, + expectedCleanCount: 1, // title + expectedRelCount: 1, // comments + }, + { + name: "User with nested posts and comments", + data: map[string]interface{}{ + "name": "Jane Doe", + "posts": []map[string]interface{}{ + { + "title": "Post 1", + "comments": []map[string]interface{}{ + {"content": "Comment 1"}, + }, + }, + }, + }, + model: TestUser{}, + expectedCleanCount: 1, // name + expectedRelCount: 1, // posts (which contains nested comments) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model) + if err != nil { + t.Errorf("extractNestedRelations() error = %v", err) + return + } + + if len(cleanedData) != tt.expectedCleanCount { + t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData) + } + + if len(relations) != tt.expectedRelCount { + t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations) + } + + t.Logf("Cleaned data: %+v", cleanedData) + t.Logf("Relations: %+v", relations) + }) + } +} + +// Test shouldUseNestedProcessor function +func TestShouldUseNestedProcessor(t *testing.T) { + registry := &mockRegistry{ + models: map[string]interface{}{ + "users": TestUser{}, + "posts": TestPost{}, + }, + } + handler := NewHandler(nil, registry) + + tests := []struct { + name string + data map[string]interface{} + model interface{} + expected bool + }{ + { + name: "Data with nested posts", + data: map[string]interface{}{ + "name": "John", + "posts": []map[string]interface{}{ + {"title": "Post 1"}, + }, + }, + model: TestUser{}, + expected: true, + }, + { + name: "Data without nested relations", + data: map[string]interface{}{ + "name": "John", + }, + model: TestUser{}, + expected: false, + }, + { + name: "Data with _request field", + data: map[string]interface{}{ + "_request": "insert", + "name": "John", + }, + model: TestUser{}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.shouldUseNestedProcessor(tt.data, tt.model) + if result != tt.expected { + t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// Test normalizeToSlice function +func TestNormalizeToSlice(t *testing.T) { + registry := &mockRegistry{} + handler := NewHandler(nil, registry) + + tests := []struct { + name string + input interface{} + expected int // expected slice length + }{ + { + name: "Single object", + input: map[string]interface{}{"name": "John"}, + expected: 1, + }, + { + name: "Slice of objects", + input: []map[string]interface{}{ + {"name": "John"}, + {"name": "Jane"}, + }, + expected: 2, + }, + { + name: "Array of interfaces", + input: []interface{}{ + map[string]interface{}{"name": "John"}, + map[string]interface{}{"name": "Jane"}, + map[string]interface{}{"name": "Bob"}, + }, + expected: 3, + }, + { + name: "Nil input", + input: nil, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.normalizeToSlice(tt.input) + if len(result) != tt.expected { + t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected) + } + }) + } +} + +// Test GetRelationshipInfo function +func TestGetRelationshipInfo(t *testing.T) { + registry := &mockRegistry{} + handler := NewHandler(nil, registry) + + tests := []struct { + name string + modelType reflect.Type + relationName string + expectNil bool + }{ + { + name: "User posts relation", + modelType: reflect.TypeOf(TestUser{}), + relationName: "posts", + expectNil: false, + }, + { + name: "Post comments relation", + modelType: reflect.TypeOf(TestPost{}), + relationName: "comments", + expectNil: false, + }, + { + name: "Non-existent relation", + modelType: reflect.TypeOf(TestUser{}), + relationName: "nonexistent", + expectNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.GetRelationshipInfo(tt.modelType, tt.relationName) + if tt.expectNil && result != nil { + t.Errorf("Expected nil, got %+v", result) + } + if !tt.expectNil && result == nil { + t.Errorf("Expected non-nil relationship info") + } + if result != nil { + t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s", + result.FieldName, result.JSONName, result.RelationType, result.ForeignKey) + } + }) + } +} + +// Mock registry for testing +type mockRegistry struct { + models map[string]interface{} +} + +func (m *mockRegistry) Register(name string, model interface{}) { + m.RegisterModel(name, model) +} + +func (m *mockRegistry) RegisterModel(name string, model interface{}) error { + if m.models == nil { + m.models = make(map[string]interface{}) + } + m.models[name] = model + return nil +} + +func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { + if model, ok := m.models[entity]; ok { + return model, nil + } + return nil, fmt.Errorf("model not found: %s", entity) +} + +func (m *mockRegistry) GetModelByName(name string) (interface{}, error) { + if model, ok := m.models[name]; ok { + return model, nil + } + return nil, fmt.Errorf("model not found: %s", name) +} + +func (m *mockRegistry) GetModel(name string) (interface{}, error) { + return m.GetModelByName(name) +} + +func (m *mockRegistry) HasModel(schema, entity string) bool { + _, ok := m.models[entity] + return ok +} + +func (m *mockRegistry) ListModels() []string { + models := make([]string, 0, len(m.models)) + for name := range m.models { + models = append(models, name) + } + return models +} + +func (m *mockRegistry) GetAllModels() map[string]interface{} { + return m.models +} + +// TestMultiLevelRelationExtraction tests extracting deeply nested relations +func TestMultiLevelRelationExtraction(t *testing.T) { + registry := &mockRegistry{ + models: map[string]interface{}{ + "users": TestUser{}, + "posts": TestPost{}, + "comments": TestComment{}, + }, + } + handler := NewHandler(nil, registry) + + // Test data with 3 levels: User -> Posts -> Comments + testData := map[string]interface{}{ + "name": "John Doe", + "posts": []map[string]interface{}{ + { + "title": "First Post", + "comments": []map[string]interface{}{ + {"content": "Great post!"}, + {"content": "Thanks for sharing!"}, + }, + }, + { + "title": "Second Post", + "comments": []map[string]interface{}{ + {"content": "Interesting read"}, + }, + }, + }, + } + + // Extract relations from user + cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{}) + if err != nil { + t.Fatalf("Failed to extract relations: %v", err) + } + + // Verify user data is cleaned + if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" { + t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData) + } + + // Verify posts relation was extracted + if len(relations) != 1 { + t.Errorf("Expected 1 relation (posts), got %d", len(relations)) + } + + posts, ok := relations["posts"] + if !ok { + t.Fatal("Expected posts relation to be extracted") + } + + // Verify posts is a slice with 2 items + postsSlice, ok := posts.([]map[string]interface{}) + if !ok { + t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts) + } + + if len(postsSlice) != 2 { + t.Errorf("Expected 2 posts, got %d", len(postsSlice)) + } + + // Verify first post has comments + if _, hasComments := postsSlice[0]["comments"]; !hasComments { + t.Error("Expected first post to have comments") + } + + t.Logf("Successfully extracted multi-level nested relations") + t.Logf("Cleaned data: %+v", cleanedData) + t.Logf("Relations: %d posts with nested comments", len(postsSlice)) +}