mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
Fixes for CUD operations
This commit is contained in:
parent
35f23b6d9e
commit
14daea3b05
@ -45,6 +45,62 @@ func GetPrimaryKeyName(model any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||||
|
// Returns the value of the primary key field
|
||||||
|
func GetPrimaryKeyValue(model any) interface{} {
|
||||||
|
if model == nil || reflect.TypeOf(model) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
val := reflect.ValueOf(model)
|
||||||
|
if val.Kind() == reflect.Pointer {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if val.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
// Try Bun tag first
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") {
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
if fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to GORM tag
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") {
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
if fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last resort: look for field named "ID" or "Id"
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
if strings.ToLower(field.Name) == "id" {
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
if fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelColumns extracts all column names from a model using reflection
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||||
func GetModelColumns(model any) []string {
|
func GetModelColumns(model any) []string {
|
||||||
|
|||||||
@ -584,22 +584,6 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
|
|
||||||
logger.Info("Creating record in %s.%s", schema, entity)
|
logger.Info("Creating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
// Check if data is a single map with nested relations
|
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
|
||||||
if h.shouldUseNestedProcessor(dataMap, model) {
|
|
||||||
logger.Info("Using nested CUD processor for create operation")
|
|
||||||
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", dataMap, model, make(map[string]interface{}), tableName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error in nested create: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record with nested data", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
|
||||||
h.sendResponseWithOptions(w, result.Data, nil, &options)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute BeforeCreate hooks
|
// Execute BeforeCreate hooks
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
@ -622,172 +606,113 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
// Use potentially modified data from hook context
|
// Use potentially modified data from hook context
|
||||||
data = hookCtx.Data
|
data = hookCtx.Data
|
||||||
|
|
||||||
// Handle batch creation
|
// Normalize data to slice for unified processing
|
||||||
dataValue := reflect.ValueOf(data)
|
dataSlice := h.normalizeToSlice(data)
|
||||||
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
||||||
logger.Debug("Batch creation detected, count: %d", dataValue.Len())
|
|
||||||
|
|
||||||
// Check if any item needs nested processing
|
// Process all items in a transaction
|
||||||
hasNestedData := false
|
results := make([]interface{}, 0, len(dataSlice))
|
||||||
for i := 0; i < dataValue.Len(); i++ {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
item := dataValue.Index(i).Interface()
|
// Create temporary nested processor with transaction
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
||||||
if h.shouldUseNestedProcessor(itemMap, model) {
|
|
||||||
hasNestedData = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasNestedData {
|
for i, item := range dataSlice {
|
||||||
logger.Info("Using nested CUD processor for batch create with nested data")
|
itemMap, ok := item.(map[string]interface{})
|
||||||
results := make([]interface{}, 0, dataValue.Len())
|
if !ok {
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
// Convert to map if needed
|
||||||
// Temporarily swap the database to use transaction
|
|
||||||
originalDB := h.nestedProcessor
|
|
||||||
h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
|
|
||||||
defer func() {
|
|
||||||
h.nestedProcessor = originalDB
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i := 0; i < dataValue.Len(); i++ {
|
|
||||||
item := dataValue.Index(i).Interface()
|
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
||||||
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", itemMap, model, make(map[string]interface{}), tableName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to process item: %w", err)
|
|
||||||
}
|
|
||||||
results = append(results, result.Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error creating records with nested data: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute AfterCreate hooks
|
|
||||||
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
|
||||||
hookCtx.Error = nil
|
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
|
||||||
logger.Error("AfterCreate hook failed: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Successfully created %d records with nested data", len(results))
|
|
||||||
h.sendResponseWithOptions(w, results, nil, &options)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Standard batch insert without nested relations
|
|
||||||
// Use transaction for batch insert
|
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
||||||
for i := 0; i < dataValue.Len(); i++ {
|
|
||||||
item := dataValue.Index(i).Interface()
|
|
||||||
|
|
||||||
// Convert item to model type - create a pointer to the model
|
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
|
||||||
jsonData, err := json.Marshal(item)
|
jsonData, err := json.Marshal(item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal item: %w", err)
|
return fmt.Errorf("failed to marshal item %d: %w", i, err)
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(jsonData, modelValue); err != nil {
|
itemMap = make(map[string]interface{})
|
||||||
return fmt.Errorf("failed to unmarshal item: %w", err)
|
if err := json.Unmarshal(jsonData, &itemMap); err != nil {
|
||||||
}
|
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
|
||||||
|
|
||||||
query := tx.NewInsert().Model(modelValue).Table(tableName)
|
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
|
||||||
batchHookCtx := &HookContext{
|
|
||||||
Context: ctx,
|
|
||||||
Handler: h,
|
|
||||||
Schema: schema,
|
|
||||||
Entity: entity,
|
|
||||||
TableName: tableName,
|
|
||||||
Model: model,
|
|
||||||
Options: options,
|
|
||||||
Data: modelValue,
|
|
||||||
Writer: w,
|
|
||||||
Query: query,
|
|
||||||
}
|
|
||||||
if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil {
|
|
||||||
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified query from hook context
|
|
||||||
if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok {
|
|
||||||
query = modifiedQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
|
||||||
return fmt.Errorf("failed to insert record: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
// Extract nested relations if present (but don't process them yet)
|
||||||
logger.Error("Error creating records: %v", err)
|
var nestedRelations map[string]interface{}
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
|
if h.shouldUseNestedProcessor(itemMap, model) {
|
||||||
return
|
logger.Debug("Extracting nested relations for item %d", i)
|
||||||
|
cleanedData, relations, err := h.extractNestedRelations(itemMap, model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extract nested relations for item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
itemMap = cleanedData
|
||||||
|
nestedRelations = relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert item to model type - create a pointer to the model
|
||||||
|
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
|
jsonData, err := json.Marshal(itemMap)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, modelValue); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create insert query
|
||||||
|
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
itemHookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
Data: modelValue,
|
||||||
|
Writer: w,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := itemHookCtx.Query.(common.InsertQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute insert and get the ID
|
||||||
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to insert item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the inserted ID
|
||||||
|
insertedID := reflection.GetPrimaryKeyValue(modelValue)
|
||||||
|
|
||||||
|
// Now process nested relations with the parent ID
|
||||||
|
if len(nestedRelations) > 0 {
|
||||||
|
logger.Debug("Processing nested relations for item %d with parent ID: %v", i, insertedID)
|
||||||
|
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "insert", nestedRelations, model, insertedID); err != nil {
|
||||||
|
return fmt.Errorf("failed to process nested relations for item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, modelValue)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Execute AfterCreate hooks for batch creation
|
|
||||||
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
|
|
||||||
hookCtx.Error = nil
|
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
|
||||||
logger.Error("AfterCreate hook failed: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single record creation - create a pointer to the model
|
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error creating records: %v", err)
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonData, modelValue); err != nil {
|
|
||||||
logger.Error("Error unmarshaling data: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewInsert().Model(modelValue).Table(tableName)
|
// Execute AfterCreate hooks
|
||||||
|
var responseData interface{}
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
if len(results) == 1 {
|
||||||
hookCtx.Data = modelValue
|
responseData = results[0]
|
||||||
hookCtx.Query = query
|
hookCtx.Result = results[0]
|
||||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
} else {
|
||||||
logger.Error("BeforeScan hook failed: %v", err)
|
responseData = results
|
||||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use potentially modified query from hook context
|
|
||||||
if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok {
|
|
||||||
query = modifiedQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
|
||||||
logger.Error("Error creating record: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute AfterCreate hooks for single record creation
|
|
||||||
hookCtx.Result = modelValue
|
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
@ -796,7 +721,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponseWithOptions(w, modelValue, nil, &options)
|
logger.Info("Successfully created %d record(s)", len(results))
|
||||||
|
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
|
||||||
@ -814,46 +740,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
logger.Info("Updating record in %s.%s", schema, entity)
|
logger.Info("Updating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
// Convert data to map first for nested processor check
|
|
||||||
dataMap, ok := data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error marshaling data: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonData, &dataMap); err != nil {
|
|
||||||
logger.Error("Error unmarshaling data: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we should use nested processing
|
|
||||||
if h.shouldUseNestedProcessor(dataMap, model) {
|
|
||||||
logger.Info("Using nested CUD processor for update operation")
|
|
||||||
// Ensure ID is in the data map
|
|
||||||
var targetID interface{}
|
|
||||||
if id != "" {
|
|
||||||
targetID = id
|
|
||||||
} else if idPtr != nil {
|
|
||||||
targetID = *idPtr
|
|
||||||
}
|
|
||||||
if targetID != nil {
|
|
||||||
dataMap["id"] = targetID
|
|
||||||
}
|
|
||||||
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", dataMap, model, make(map[string]interface{}), tableName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error in nested update: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record with nested data", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
|
||||||
h.sendResponseWithOptions(w, result.Data, nil, &options)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute BeforeUpdate hooks
|
// Execute BeforeUpdate hooks
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
@ -877,8 +763,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Use potentially modified data from hook context
|
// Use potentially modified data from hook context
|
||||||
data = hookCtx.Data
|
data = hookCtx.Data
|
||||||
|
|
||||||
// Convert data to map (again if modified by hooks)
|
// Convert data to map
|
||||||
dataMap, ok = data.(map[string]interface{})
|
dataMap, ok := data.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -893,33 +779,74 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewUpdate().Table(tableName).SetMap(dataMap)
|
// Determine target ID
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
var targetID interface{}
|
||||||
// Apply ID filter
|
if id != "" {
|
||||||
switch {
|
targetID = id
|
||||||
case id != "":
|
} else if idPtr != nil {
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
targetID = *idPtr
|
||||||
case idPtr != nil:
|
} else {
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *idPtr)
|
|
||||||
default:
|
|
||||||
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil)
|
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Process nested relations if present
|
||||||
hookCtx.Query = query
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
// Create temporary nested processor with transaction
|
||||||
logger.Error("BeforeScan hook failed: %v", err)
|
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
||||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified query from hook context
|
// Extract nested relations if present (but don't process them yet)
|
||||||
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
|
var nestedRelations map[string]interface{}
|
||||||
query = modifiedQuery
|
if h.shouldUseNestedProcessor(dataMap, model) {
|
||||||
}
|
logger.Debug("Extracting nested relations for update")
|
||||||
|
cleanedData, relations, err := h.extractNestedRelations(dataMap, model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extract nested relations: %w", err)
|
||||||
|
}
|
||||||
|
dataMap = cleanedData
|
||||||
|
nestedRelations = relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure ID is in the data map for the update
|
||||||
|
dataMap["id"] = targetID
|
||||||
|
|
||||||
|
// Create update query
|
||||||
|
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute update
|
||||||
|
result, err := query.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now process nested relations with the parent ID
|
||||||
|
if len(nestedRelations) > 0 {
|
||||||
|
logger.Debug("Processing nested relations for update with parent ID: %v", targetID)
|
||||||
|
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "update", nestedRelations, model, targetID); err != nil {
|
||||||
|
return fmt.Errorf("failed to process nested relations: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store result for hooks
|
||||||
|
hookCtx.Result = map[string]interface{}{
|
||||||
|
"updated": result.RowsAffected(),
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error updating record: %v", err)
|
logger.Error("Error updating record: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err)
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err)
|
||||||
@ -927,19 +854,15 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterUpdate hooks
|
// Execute AfterUpdate hooks
|
||||||
responseData := map[string]interface{}{
|
|
||||||
"updated": result.RowsAffected(),
|
|
||||||
}
|
|
||||||
hookCtx.Result = responseData
|
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
logger.Error("AfterUpdate hook failed: %v", err)
|
logger.Error("AfterUpdate hook failed: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||||
|
h.sendResponseWithOptions(w, hookCtx.Result, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
||||||
@ -1199,6 +1122,196 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
h.sendResponse(w, responseData, nil)
|
h.sendResponse(w, responseData, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
||||||
|
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
||||||
|
if data == nil {
|
||||||
|
return []interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
dataValue := reflect.ValueOf(data)
|
||||||
|
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
||||||
|
result := make([]interface{}, dataValue.Len())
|
||||||
|
for i := 0; i < dataValue.Len(); i++ {
|
||||||
|
result[i] = dataValue.Index(i).Interface()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single item - return as 1-item slice
|
||||||
|
return []interface{}{data}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractNestedRelations extracts nested relations from data, returning cleaned data and relations
|
||||||
|
// This does NOT process the relations, just separates them for later processing
|
||||||
|
func (h *Handler) extractNestedRelations(
|
||||||
|
data map[string]interface{},
|
||||||
|
model interface{},
|
||||||
|
) (map[string]interface{}, map[string]interface{}, error) {
|
||||||
|
// Get model type for reflection
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return data, nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Separate relation fields from regular fields
|
||||||
|
cleanedData := make(map[string]interface{})
|
||||||
|
relations := make(map[string]interface{})
|
||||||
|
|
||||||
|
for key, value := range data {
|
||||||
|
// Skip _request field
|
||||||
|
if key == "_request" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this field is a relation
|
||||||
|
relInfo := h.GetRelationshipInfo(modelType, key)
|
||||||
|
if relInfo != nil {
|
||||||
|
logger.Debug("Found nested relation field: %s (type: %s)", key, relInfo.RelationType)
|
||||||
|
relations[key] = value
|
||||||
|
} else {
|
||||||
|
cleanedData[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleanedData, relations, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processChildRelationsWithParentID processes nested relations with a parent ID
|
||||||
|
func (h *Handler) processChildRelationsWithParentID(
|
||||||
|
ctx context.Context,
|
||||||
|
processor *common.NestedCUDProcessor,
|
||||||
|
operation string,
|
||||||
|
relations map[string]interface{},
|
||||||
|
parentModel interface{},
|
||||||
|
parentID interface{},
|
||||||
|
) error {
|
||||||
|
// Get model type for reflection
|
||||||
|
modelType := reflect.TypeOf(parentModel)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each relation
|
||||||
|
for relationName, relationValue := range relations {
|
||||||
|
if relationValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get relationship info
|
||||||
|
relInfo := h.GetRelationshipInfo(modelType, relationName)
|
||||||
|
if relInfo == nil {
|
||||||
|
logger.Warn("No relationship info found for %s, skipping", relationName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process this relation with parent ID
|
||||||
|
if err := h.processChildRelationsForField(ctx, processor, operation, relationName, relationValue, relInfo, modelType, parentID); err != nil {
|
||||||
|
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processChildRelationsForField processes a single nested relation field
|
||||||
|
func (h *Handler) processChildRelationsForField(
|
||||||
|
ctx context.Context,
|
||||||
|
processor *common.NestedCUDProcessor,
|
||||||
|
operation string,
|
||||||
|
relationName string,
|
||||||
|
relationValue interface{},
|
||||||
|
relInfo *common.RelationshipInfo,
|
||||||
|
parentModelType reflect.Type,
|
||||||
|
parentID interface{},
|
||||||
|
) error {
|
||||||
|
if relationValue == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the related model
|
||||||
|
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("field %s not found in model", relInfo.FieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the model type for the relation
|
||||||
|
relatedModelType := field.Type
|
||||||
|
if relatedModelType.Kind() == reflect.Slice {
|
||||||
|
relatedModelType = relatedModelType.Elem()
|
||||||
|
}
|
||||||
|
if relatedModelType.Kind() == reflect.Ptr {
|
||||||
|
relatedModelType = relatedModelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an instance of the related model
|
||||||
|
relatedModel := reflect.New(relatedModelType).Elem().Interface()
|
||||||
|
|
||||||
|
// Get table name for related model
|
||||||
|
relatedTableName := h.getTableNameForRelatedModel(relatedModel, relInfo.JSONName)
|
||||||
|
|
||||||
|
// Prepare parent IDs for foreign key injection
|
||||||
|
parentIDs := make(map[string]interface{})
|
||||||
|
if relInfo.ForeignKey != "" && parentID != nil {
|
||||||
|
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
||||||
|
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||||
|
parentIDs[baseName] = parentID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process based on relation type and data structure
|
||||||
|
switch v := relationValue.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Single related object
|
||||||
|
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to process single relation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case []interface{}:
|
||||||
|
// Multiple related objects
|
||||||
|
for i, item := range v {
|
||||||
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []map[string]interface{}:
|
||||||
|
// Multiple related objects (typed slice)
|
||||||
|
for i, itemMap := range v {
|
||||||
|
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported relation data type: %T", relationValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableNameForRelatedModel gets the table name for a related model
|
||||||
|
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
|
||||||
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
tableName := provider.TableName()
|
||||||
|
if tableName != "" {
|
||||||
|
return tableName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultName
|
||||||
|
}
|
||||||
|
|
||||||
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
||||||
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
||||||
// Check if column already has a table/schema prefix (contains a dot)
|
// Check if column already has a table/schema prefix (contains a dot)
|
||||||
|
|||||||
393
pkg/restheadspec/handler_nested_test.go
Normal file
393
pkg/restheadspec/handler_nested_test.go
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models for nested CRUD operations
|
||||||
|
type TestUser struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestPost struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestComment struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||||
|
PostID int64 `json:"post_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (TestUser) TableName() string { return "users" }
|
||||||
|
func (TestPost) TableName() string { return "posts" }
|
||||||
|
func (TestComment) TableName() string { return "comments" }
|
||||||
|
|
||||||
|
// Test extractNestedRelations function
|
||||||
|
func TestExtractNestedRelations(t *testing.T) {
|
||||||
|
// Create handler
|
||||||
|
registry := &mockRegistry{
|
||||||
|
models: map[string]interface{}{
|
||||||
|
"users": TestUser{},
|
||||||
|
"posts": TestPost{},
|
||||||
|
"comments": TestComment{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]interface{}
|
||||||
|
model interface{}
|
||||||
|
expectedCleanCount int
|
||||||
|
expectedRelCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User with posts",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{"title": "Post 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expectedCleanCount: 1, // name
|
||||||
|
expectedRelCount: 1, // posts
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Post with comments",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"title": "Test Post",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Comment 1"},
|
||||||
|
{"content": "Comment 2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestPost{},
|
||||||
|
expectedCleanCount: 1, // title
|
||||||
|
expectedRelCount: 1, // comments
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User with nested posts and comments",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "Jane Doe",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"title": "Post 1",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Comment 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expectedCleanCount: 1, // name
|
||||||
|
expectedRelCount: 1, // posts (which contains nested comments)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("extractNestedRelations() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cleanedData) != tt.expectedCleanCount {
|
||||||
|
t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relations) != tt.expectedRelCount {
|
||||||
|
t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Cleaned data: %+v", cleanedData)
|
||||||
|
t.Logf("Relations: %+v", relations)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test shouldUseNestedProcessor function
|
||||||
|
func TestShouldUseNestedProcessor(t *testing.T) {
|
||||||
|
registry := &mockRegistry{
|
||||||
|
models: map[string]interface{}{
|
||||||
|
"users": TestUser{},
|
||||||
|
"posts": TestPost{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]interface{}
|
||||||
|
model interface{}
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Data with nested posts",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{"title": "Post 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data without nested relations",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data with _request field",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"_request": "insert",
|
||||||
|
"name": "John",
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.shouldUseNestedProcessor(tt.data, tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test normalizeToSlice function
|
||||||
|
func TestNormalizeToSlice(t *testing.T) {
|
||||||
|
registry := &mockRegistry{}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected int // expected slice length
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single object",
|
||||||
|
input: map[string]interface{}{"name": "John"},
|
||||||
|
expected: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Slice of objects",
|
||||||
|
input: []map[string]interface{}{
|
||||||
|
{"name": "John"},
|
||||||
|
{"name": "Jane"},
|
||||||
|
},
|
||||||
|
expected: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array of interfaces",
|
||||||
|
input: []interface{}{
|
||||||
|
map[string]interface{}{"name": "John"},
|
||||||
|
map[string]interface{}{"name": "Jane"},
|
||||||
|
map[string]interface{}{"name": "Bob"},
|
||||||
|
},
|
||||||
|
expected: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil input",
|
||||||
|
input: nil,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.normalizeToSlice(tt.input)
|
||||||
|
if len(result) != tt.expected {
|
||||||
|
t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetRelationshipInfo function
|
||||||
|
func TestGetRelationshipInfo(t *testing.T) {
|
||||||
|
registry := &mockRegistry{}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
modelType reflect.Type
|
||||||
|
relationName string
|
||||||
|
expectNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User posts relation",
|
||||||
|
modelType: reflect.TypeOf(TestUser{}),
|
||||||
|
relationName: "posts",
|
||||||
|
expectNil: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Post comments relation",
|
||||||
|
modelType: reflect.TypeOf(TestPost{}),
|
||||||
|
relationName: "comments",
|
||||||
|
expectNil: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-existent relation",
|
||||||
|
modelType: reflect.TypeOf(TestUser{}),
|
||||||
|
relationName: "nonexistent",
|
||||||
|
expectNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.GetRelationshipInfo(tt.modelType, tt.relationName)
|
||||||
|
if tt.expectNil && result != nil {
|
||||||
|
t.Errorf("Expected nil, got %+v", result)
|
||||||
|
}
|
||||||
|
if !tt.expectNil && result == nil {
|
||||||
|
t.Errorf("Expected non-nil relationship info")
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s",
|
||||||
|
result.FieldName, result.JSONName, result.RelationType, result.ForeignKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock registry for testing
|
||||||
|
type mockRegistry struct {
|
||||||
|
models map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) Register(name string, model interface{}) {
|
||||||
|
m.RegisterModel(name, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) RegisterModel(name string, model interface{}) error {
|
||||||
|
if m.models == nil {
|
||||||
|
m.models = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
m.models[name] = model
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
|
||||||
|
if model, ok := m.models[entity]; ok {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("model not found: %s", entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetModelByName(name string) (interface{}, error) {
|
||||||
|
if model, ok := m.models[name]; ok {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("model not found: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetModel(name string) (interface{}, error) {
|
||||||
|
return m.GetModelByName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) HasModel(schema, entity string) bool {
|
||||||
|
_, ok := m.models[entity]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) ListModels() []string {
|
||||||
|
models := make([]string, 0, len(m.models))
|
||||||
|
for name := range m.models {
|
||||||
|
models = append(models, name)
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetAllModels() map[string]interface{} {
|
||||||
|
return m.models
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
||||||
|
func TestMultiLevelRelationExtraction(t *testing.T) {
|
||||||
|
registry := &mockRegistry{
|
||||||
|
models: map[string]interface{}{
|
||||||
|
"users": TestUser{},
|
||||||
|
"posts": TestPost{},
|
||||||
|
"comments": TestComment{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
// Test data with 3 levels: User -> Posts -> Comments
|
||||||
|
testData := map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"title": "First Post",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Great post!"},
|
||||||
|
{"content": "Thanks for sharing!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Second Post",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Interesting read"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract relations from user
|
||||||
|
cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to extract relations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify user data is cleaned
|
||||||
|
if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" {
|
||||||
|
t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify posts relation was extracted
|
||||||
|
if len(relations) != 1 {
|
||||||
|
t.Errorf("Expected 1 relation (posts), got %d", len(relations))
|
||||||
|
}
|
||||||
|
|
||||||
|
posts, ok := relations["posts"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected posts relation to be extracted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify posts is a slice with 2 items
|
||||||
|
postsSlice, ok := posts.([]map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(postsSlice) != 2 {
|
||||||
|
t.Errorf("Expected 2 posts, got %d", len(postsSlice))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify first post has comments
|
||||||
|
if _, hasComments := postsSlice[0]["comments"]; !hasComments {
|
||||||
|
t.Error("Expected first post to have comments")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Successfully extracted multi-level nested relations")
|
||||||
|
t.Logf("Cleaned data: %+v", cleanedData)
|
||||||
|
t.Logf("Relations: %d posts with nested comments", len(postsSlice))
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user