Compare commits

...

12 Commits

Author SHA1 Message Date
Hein
cb54ec5e27 Better responses for updates and inserts 2025-11-20 09:57:24 +02:00
Hein
7d6a9025f5 Fixed hardcoded id 2025-11-20 09:40:11 +02:00
Hein
35089f511f correctly handle structs with embedded fields 2025-11-20 09:28:37 +02:00
Hein
66b6a0d835 Better registry handling
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-19 18:29:24 +02:00
Hein
456c165814 Fixed models being icorrectly set and added SetDefaultRegistry 2025-11-19 18:22:56 +02:00
Hein
850d7b546c Added modelregistry.AddRegistry 2025-11-19 18:18:18 +02:00
Hein
a44ef90d7c Fixes on getRelationshipInfo, ShouldUseNestedProcessor 2025-11-19 18:03:25 +02:00
Hein
8b7db5b31a reflection-based column validation for UpdateQuery 2025-11-19 17:41:15 +02:00
Hein
14daea3b05 Fixes for CUD operations 2025-11-19 15:08:04 +02:00
Hein
35f23b6d9e Recursive crud fix 2025-11-19 14:32:20 +02:00
Hein
53a4e67f70 Specifically call update if a ID was given. 2025-11-19 14:24:39 +02:00
Hein
1289c3af88 Fixed handling post routes as well for the restheadspec
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-19 14:04:56 +02:00
12 changed files with 1783 additions and 281 deletions

View File

@@ -9,6 +9,8 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// BunAdapter adapts Bun to work with our Database interface // BunAdapter adapts Bun to work with our Database interface
@@ -353,25 +355,50 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
// BunUpdateQuery implements UpdateQuery for Bun // BunUpdateQuery implements UpdateQuery for Bun
type BunUpdateQuery struct { type BunUpdateQuery struct {
query *bun.UpdateQuery query *bun.UpdateQuery
model interface{}
} }
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery { func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
b.query = b.query.Model(model) b.query = b.query.Model(model)
b.model = model
return b return b
} }
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery { func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
b.query = b.query.Table(table) b.query = b.query.Table(table)
if b.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
b.model = model
}
}
return b return b
} }
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
// Skip scan-only columns
return b
}
b.query = b.query.Set(column+" = ?", value) b.query = b.query.Set(column+" = ?", value)
return b return b
} }
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
pkName := reflection.GetPrimaryKeyName(b.model)
for column, value := range values { for column, value := range values {
// Validate column is writable if model is set
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
// Skip scan-only columns
continue
}
if pkName != "" && column == pkName {
// Skip primary key updates
continue
}
b.query = b.query.Set(column+" = ?", value) b.query = b.query.Set(column+" = ?", value)
} }
return b return b

View File

@@ -8,6 +8,8 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// GormAdapter adapts GORM to work with our Database interface // GormAdapter adapts GORM to work with our Database interface
@@ -97,6 +99,7 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table) g.db = g.db.Table(table)
// Check if the table name contains schema (e.g., "schema.table") // Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table) g.schema, g.tableName = parseTableName(table)
return g return g
} }
@@ -339,10 +342,23 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery { func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
g.db = g.db.Table(table) g.db = g.db.Table(table)
if g.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
g.model = model
}
}
return g return g
} }
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
// Skip read-only columns
return g
}
if g.updates == nil { if g.updates == nil {
g.updates = make(map[string]interface{}) g.updates = make(map[string]interface{})
} }
@@ -353,7 +369,25 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
} }
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
g.updates = values
// Filter out read-only columns if model is set
if g.model != nil {
pkName := reflection.GetPrimaryKeyName(g.model)
filteredValues := make(map[string]interface{})
for column, value := range values {
if pkName != "" && column == pkName {
// Skip primary key updates
continue
}
if reflection.IsColumnWritable(g.model, column) {
filteredValues[column] = value
}
}
g.updates = filteredValues
} else {
g.updates = values
}
return g return g
} }

View File

@@ -0,0 +1,161 @@
package database
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Test models for bun
type BunTestModel struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"email"`
ComputedCol string `bun:"computed_col,scanonly"`
}
// Test models for gorm
type GormTestModel struct {
ID int `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name"`
Email string `gorm:"column:email"`
ReadOnlyCol string `gorm:"column:readonly_col;->"`
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
}
func TestIsColumnWritable_Bun(t *testing.T) {
model := &BunTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "scanonly column should not be writable",
columnName: "computed_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestIsColumnWritable_Gorm(t *testing.T) {
model := &GormTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "read-only column with -> should not be writable",
columnName: "readonly_col",
expected: false,
},
{
name: "column with <-:false should not be writable",
columnName: "nowrite_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
// Note: This is a unit test for the validation logic only.
// We can't fully test the bun query without a database connection,
// but we've verified the validation logic in TestIsColumnWritable_Bun
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
}
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
model := &GormTestModel{}
query := &GormUpdateQuery{
model: model,
}
// SetMap should filter out read-only columns
values := map[string]interface{}{
"name": "John",
"email": "john@example.com",
"readonly_col": "should_be_filtered",
"nowrite_col": "should_also_be_filtered",
}
query.SetMap(values)
// Check that the updates map only contains writable columns
if updates, ok := query.updates.(map[string]interface{}); ok {
if _, exists := updates["readonly_col"]; exists {
t.Error("readonly_col should have been filtered out")
}
if _, exists := updates["nowrite_col"]; exists {
t.Error("nowrite_col should have been filtered out")
}
if _, exists := updates["name"]; !exists {
t.Error("name should be in updates")
}
if _, exists := updates["email"]; !exists {
t.Error("email should be in updates")
}
} else {
t.Error("updates should be a map[string]interface{}")
}
}

View File

@@ -111,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
// Inject parent IDs for foreign key resolution // Inject parent IDs for foreign key resolution
p.injectForeignKeys(regularData, modelType, parentIDs) p.injectForeignKeys(regularData, modelType, parentIDs)
// Get the primary key name for this model
pkName := reflection.GetPrimaryKeyName(model)
// Process based on operation // Process based on operation
switch strings.ToLower(operation) { switch strings.ToLower(operation) {
case "insert", "create": case "insert", "create":
@@ -128,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
} }
case "update": case "update":
rows, err := p.processUpdate(ctx, regularData, tableName, data["id"]) rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
if err != nil { if err != nil {
return nil, fmt.Errorf("update failed: %w", err) return nil, fmt.Errorf("update failed: %w", err)
} }
result.ID = data["id"] result.ID = data[pkName]
result.AffectedRows = rows result.AffectedRows = rows
result.Data = regularData result.Data = regularData
// Process child relations for update // Process child relations for update
if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil { if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations: %w", err) return nil, fmt.Errorf("failed to process child relations: %w", err)
} }
case "delete": case "delete":
// Process child relations first (for referential integrity) // Process child relations first (for referential integrity)
if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil { if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations before delete: %w", err) return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
} }
rows, err := p.processDelete(ctx, tableName, data["id"]) rows, err := p.processDelete(ctx, tableName, data[pkName])
if err != nil { if err != nil {
return nil, fmt.Errorf("delete failed: %w", err) return nil, fmt.Errorf("delete failed: %w", err)
} }
result.ID = data["id"] result.ID = data[pkName]
result.AffectedRows = rows result.AffectedRows = rows
result.Data = regularData result.Data = regularData
@@ -378,8 +381,16 @@ func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName
} }
// ShouldUseNestedProcessor determines if we should use nested CUD processing // ShouldUseNestedProcessor determines if we should use nested CUD processing
// It checks if the data contains nested relations or a _request field // It recursively checks if the data contains:
// 1. A _request field at any level, OR
// 2. Nested relations that themselves contain further nested relations or _request fields
// This ensures nested processing is only used when there are deeply nested operations
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool { func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
}
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
// Check for _request field // Check for _request field
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest { if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
return true return true
@@ -406,10 +417,34 @@ func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, re
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key) relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
if relInfo != nil { if relInfo != nil {
// Check if the value is actually nested data (object or array) // Check if the value is actually nested data (object or array)
switch value.(type) { switch v := value.(type) {
case map[string]interface{}, []interface{}, []map[string]interface{}: case map[string]interface{}, []interface{}, []map[string]interface{}:
logger.Debug("Found nested relation field: %s", key) // If we're already at a nested level (depth > 0) and found a relation,
return true // that means we have multi-level nesting, so return true
if depth > 0 {
return true
}
// At depth 0, recurse to check if the nested data has further nesting
switch typedValue := v.(type) {
case map[string]interface{}:
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
case []interface{}:
for _, item := range typedValue {
if itemMap, ok := item.(map[string]interface{}); ok {
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
}
}
case []map[string]interface{}:
for _, itemMap := range typedValue {
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
}
}
} }
} }
} }

View File

@@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
models: make(map[string]interface{}), models: make(map[string]interface{}),
} }
// Global list of registries (searched in order)
var registries = []*DefaultModelRegistry{defaultRegistry}
var registriesMutex sync.RWMutex
// NewModelRegistry creates a new model registry // NewModelRegistry creates a new model registry
func NewModelRegistry() *DefaultModelRegistry { func NewModelRegistry() *DefaultModelRegistry {
return &DefaultModelRegistry{ return &DefaultModelRegistry{
@@ -24,6 +28,34 @@ func NewModelRegistry() *DefaultModelRegistry {
} }
} }
func SetDefaultRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock()
foundAt := -1
for idx, r := range registries {
if r == defaultRegistry {
foundAt = idx
break
}
}
defaultRegistry = registry
if foundAt >= 0 {
registries[foundAt] = registry
} else {
registries = append([]*DefaultModelRegistry{registry}, registries...)
}
defer registriesMutex.Unlock()
}
// AddRegistry adds a registry to the global list of registries
// Registries are searched in the order they were added
func AddRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock()
defer registriesMutex.Unlock()
registries = append(registries, registry)
}
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error { func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@@ -107,9 +139,19 @@ func RegisterModel(model interface{}, name string) error {
return defaultRegistry.RegisterModel(name, model) return defaultRegistry.RegisterModel(name, model)
} }
// GetModelByName retrieves a model from the default global registry by name // GetModelByName retrieves a model by searching through all registries in order
// Returns the first match found
func GetModelByName(name string) (interface{}, error) { func GetModelByName(name string) (interface{}, error) {
return defaultRegistry.GetModel(name) registriesMutex.RLock()
defer registriesMutex.RUnlock()
for _, registry := range registries {
if model, err := registry.GetModel(name); err == nil {
return model, nil
}
}
return nil, fmt.Errorf("model %s not found in any registry", name)
} }
// IterateModels iterates over all models in the default global registry // IterateModels iterates over all models in the default global registry
@@ -122,14 +164,26 @@ func IterateModels(fn func(name string, model interface{})) {
} }
} }
// GetModels returns a list of all models in the default global registry // GetModels returns a list of all models from all registries
// Models are collected in registry order, with duplicates included
func GetModels() []interface{} { func GetModels() []interface{} {
defaultRegistry.mutex.RLock() registriesMutex.RLock()
defer defaultRegistry.mutex.RUnlock() defer registriesMutex.RUnlock()
models := make([]interface{}, 0, len(defaultRegistry.models)) var models []interface{}
for _, model := range defaultRegistry.models { seen := make(map[string]bool)
models = append(models, model)
for _, registry := range registries {
registry.mutex.RLock()
for name, model := range registry.models {
// Only add the first occurrence of each model name
if !seen[name] {
models = append(models, model)
seen[name] = true
}
}
registry.mutex.RUnlock()
} }
return models return models
} }

View File

@@ -18,6 +18,7 @@ type ModelFieldDetail struct {
} }
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model // GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
// This function recursively processes embedded structs to include their fields
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
if record.Kind() != reflect.Struct { if record.Kind() != reflect.Struct {
return lst return lst
} }
collectFieldDetails(record, &lst)
return lst
}
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
modeltype := record.Type() modeltype := record.Type()
for i := 0; i < modeltype.NumField(); i++ { for i := 0; i < modeltype.NumField(); i++ {
fieldtype := modeltype.Field(i) fieldtype := modeltype.Field(i)
fieldValue := record.Field(i)
// Check if this is an embedded struct
if fieldtype.Anonymous {
// Unwrap pointer type if necessary
embeddedValue := fieldValue
if fieldValue.Kind() == reflect.Pointer {
if fieldValue.IsNil() {
// Skip nil embedded pointers
continue
}
embeddedValue = fieldValue.Elem()
}
// Recursively process embedded struct
if embeddedValue.Kind() == reflect.Struct {
collectFieldDetails(embeddedValue, lst)
continue
}
}
gormdetail := fieldtype.Tag.Get("gorm") gormdetail := fieldtype.Tag.Get("gorm")
gormdetail = strings.Trim(gormdetail, " ") gormdetail = strings.Trim(gormdetail, " ")
fielddetail := ModelFieldDetail{} fielddetail := ModelFieldDetail{}
fielddetail.FieldValue = record.Field(i) fielddetail.FieldValue = fieldValue
fielddetail.Name = fieldtype.Name fielddetail.Name = fieldtype.Name
fielddetail.DataType = fieldtype.Type.Name() fielddetail.DataType = fieldtype.Type.Name()
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:") fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
@@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
} }
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;" // ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
lst = append(lst, fielddetail) *lst = append(*lst, fielddetail)
} }
return lst
} }
func fnFindKeyVal(src, key string) string { func fnFindKeyVal(src, key string) string {

View File

@@ -45,8 +45,104 @@ func GetPrimaryKeyName(model any) string {
return "" return ""
} }
// GetPrimaryKeyValue extracts the primary key value from a model instance
// Returns the value of the primary key field
func GetPrimaryKeyValue(model any) any {
if model == nil || reflect.TypeOf(model) == nil {
return nil
}
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil
}
// Try Bun tag first
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
return pkValue
}
// Fall back to GORM tag
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
return pkValue
}
// Last resort: look for field named "ID" or "Id"
if pkValue := findFieldByName(val, "id"); pkValue != nil {
return pkValue
}
return nil
}
// findPrimaryKeyValue recursively searches for a primary key field in the struct
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
return pkValue
}
continue
}
// Check for primary key tag
switch ormType {
case "bun":
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
case "gorm":
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
}
return nil
}
// findFieldByName recursively searches for a field by name in the struct
func findFieldByName(val reflect.Value, name string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if result := findFieldByName(fieldValue, name); result != nil {
return result
}
continue
}
// Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
return nil
}
// GetModelColumns extracts all column names from a model using reflection // GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names // It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
// This function recursively processes embedded structs to include their fields
func GetModelColumns(model any) []string { func GetModelColumns(model any) []string {
var columns []string var columns []string
@@ -62,18 +158,38 @@ func GetModelColumns(model any) []string {
return columns return columns
} }
for i := 0; i < modelType.NumField(); i++ { collectColumnsFromType(modelType, &columns)
field := modelType.Field(i)
return columns
}
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively process embedded struct
if fieldType.Kind() == reflect.Struct {
collectColumnsFromType(fieldType, columns)
continue
}
}
// Get column name using the same logic as primary key extraction // Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field) columnName := getColumnNameFromField(field)
if columnName != "" { if columnName != "" {
columns = append(columns, columnName) *columns = append(*columns, columnName)
} }
} }
return columns
} }
// getColumnNameFromField extracts the column name from a struct field // getColumnNameFromField extracts the column name from a struct field
@@ -110,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
} }
// getPrimaryKeyFromReflection uses reflection to find the primary key field // getPrimaryKeyFromReflection uses reflection to find the primary key field
// This function recursively searches embedded structs
func getPrimaryKeyFromReflection(model any, ormType string) string { func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer { if val.Kind() == reflect.Pointer {
@@ -121,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
} }
typ := val.Type() typ := val.Type()
return findPrimaryKeyNameFromType(typ, ormType)
}
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
return pkName
}
}
continue
}
switch ormType { switch ormType {
case "gorm": case "gorm":
// Check for gorm tag with primaryKey // Check for gorm tag with primaryKey
@@ -175,8 +314,129 @@ func ExtractColumnFromGormTag(tag string) string {
// Example: ",pk" -> "" (will fall back to json tag) // Example: ",pk" -> "" (will fall back to json tag)
func ExtractColumnFromBunTag(tag string) string { func ExtractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",") parts := strings.Split(tag, ",")
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
return ""
}
if len(parts) > 0 && parts[0] != "" { if len(parts) > 0 && parts[0] != "" {
return parts[0] return parts[0]
} }
return "" return ""
} }
// IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
// This function recursively searches embedded structs
func IsColumnWritable(model any, columnName string) bool {
modelType := reflect.TypeOf(model)
// Unwrap pointers to get to the base struct type
for modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return false
}
found, writable := isColumnWritableInType(modelType, columnName)
if found {
return writable
}
// Column not found in model, allow it (might be a dynamic column)
return true
}
// isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if found, writable := isColumnWritableInType(fieldType, columnName); found {
return true, writable
}
}
continue
}
// Check if this field matches the column name
fieldColumnName := getColumnNameFromField(field)
if fieldColumnName != columnName {
continue
}
// Found the field, now check if it's writable
// Check bun tag for scanonly
bunTag := field.Tag.Get("bun")
if bunTag != "" {
if isBunFieldScanOnly(bunTag) {
return true, false
}
}
// Check gorm tag for write restrictions
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
if isGormFieldReadOnly(gormTag) {
return true, false
}
}
// Column is writable
return true, true
}
// Column not found
return false, false
}
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
// Example: "column_name,scanonly" -> true
func isBunFieldScanOnly(tag string) bool {
parts := strings.Split(tag, ",")
for _, part := range parts {
if strings.TrimSpace(part) == "scanonly" {
return true
}
}
return false
}
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
// Examples:
// - "<-:false" -> true (no writes allowed)
// - "->" -> true (read-only, common pattern)
// - "column:name;->" -> true
// - "<-:create" -> false (writes allowed on create)
func isGormFieldReadOnly(tag string) bool {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
// Check for read-only marker
if part == "->" {
return true
}
// Check for write restrictions
if value, found := strings.CutPrefix(part, "<-:"); found {
if value == "false" {
return true
}
}
}
return false
}

View File

@@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
}) })
} }
} }
// Test models with embedded structs
type BaseModel struct {
ID int `bun:"rid_base,pk" json:"id"`
CreatedAt string `bun:"created_at" json:"created_at"`
}
type AdhocBuffer struct {
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
type ModelWithEmbedded struct {
BaseModel
Name string `bun:"name" json:"name"`
Description string `bun:"description" json:"description"`
AdhocBuffer
}
type GormBaseModel struct {
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
CreatedAt string `gorm:"column:created_at" json:"created_at"`
}
type GormAdhocBuffer struct {
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
type GormModelWithEmbedded struct {
GormBaseModel
Name string `gorm:"column:name" json:"name"`
Description string `gorm:"column:description" json:"description"`
GormAdhocBuffer
}
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
expected string
}{
{
name: "Bun model with embedded base",
model: ModelWithEmbedded{},
expected: "rid_base",
},
{
name: "Bun model with embedded base (pointer)",
model: &ModelWithEmbedded{},
expected: "rid_base",
},
{
name: "GORM model with embedded base",
model: GormModelWithEmbedded{},
expected: "rid_base",
},
{
name: "GORM model with embedded base (pointer)",
model: &GormModelWithEmbedded{},
expected: "rid_base",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyName(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
bunModel := ModelWithEmbedded{
BaseModel: BaseModel{
ID: 123,
CreatedAt: "2024-01-01",
},
Name: "Test",
Description: "Test Description",
}
gormModel := GormModelWithEmbedded{
GormBaseModel: GormBaseModel{
ID: 456,
CreatedAt: "2024-01-02",
},
Name: "GORM Test",
Description: "GORM Test Description",
}
tests := []struct {
name string
model any
expected any
}{
{
name: "Bun model with embedded base",
model: bunModel,
expected: 123,
},
{
name: "Bun model with embedded base (pointer)",
model: &bunModel,
expected: 123,
},
{
name: "GORM model with embedded base",
model: gormModel,
expected: 456,
},
{
name: "GORM model with embedded base (pointer)",
model: &gormModel,
expected: 456,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyValue(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetModelColumnsWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with embedded structs",
model: ModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "Bun model with embedded structs (pointer)",
model: &ModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "GORM model with embedded structs",
model: GormModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "GORM model with embedded structs (pointer)",
model: &GormModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
}
}
})
}
}
func TestIsColumnWritableWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
columnName string
expected bool
}{
{
name: "Bun model - writable column in main struct",
model: ModelWithEmbedded{},
columnName: "name",
expected: true,
},
{
name: "Bun model - writable column in embedded base",
model: ModelWithEmbedded{},
columnName: "rid_base",
expected: true,
},
{
name: "Bun model - scanonly column in embedded adhoc buffer",
model: ModelWithEmbedded{},
columnName: "cql1",
expected: false,
},
{
name: "Bun model - scanonly column _rownumber",
model: ModelWithEmbedded{},
columnName: "_rownumber",
expected: false,
},
{
name: "GORM model - writable column in main struct",
model: GormModelWithEmbedded{},
columnName: "name",
expected: true,
},
{
name: "GORM model - writable column in embedded base",
model: GormModelWithEmbedded{},
columnName: "rid_base",
expected: true,
},
{
name: "GORM model - readonly column in embedded adhoc buffer",
model: GormModelWithEmbedded{},
columnName: "cql1",
expected: false,
},
{
name: "GORM model - readonly column _rownumber",
model: GormModelWithEmbedded{},
columnName: "_rownumber",
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsColumnWritable(tt.model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"reflect" "reflect"
"runtime/debug" "runtime/debug"
"strconv"
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
@@ -133,9 +134,15 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err) h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err)
return return
} }
h.handleCreate(ctx, w, data, options) validId, _ := strconv.ParseInt(id, 10, 64)
if validId > 0 {
h.handleUpdate(ctx, w, id, nil, data, options)
} else {
h.handleCreate(ctx, w, data, options)
}
case "PUT", "PATCH": case "PUT", "PATCH":
// Update operation // Update operation
body, err := r.Body() body, err := r.Body()
if err != nil { if err != nil {
logger.Error("Failed to read request body: %v", err) logger.Error("Failed to read request body: %v", err)
@@ -577,22 +584,6 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
logger.Info("Creating record in %s.%s", schema, entity) logger.Info("Creating record in %s.%s", schema, entity)
// Check if data is a single map with nested relations
if dataMap, ok := data.(map[string]interface{}); ok {
if h.shouldUseNestedProcessor(dataMap, model) {
logger.Info("Using nested CUD processor for create operation")
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", dataMap, model, make(map[string]interface{}), tableName)
if err != nil {
logger.Error("Error in nested create: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record with nested data", err)
return
}
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
h.sendResponseWithOptions(w, result.Data, nil, &options)
return
}
}
// Execute BeforeCreate hooks // Execute BeforeCreate hooks
hookCtx := &HookContext{ hookCtx := &HookContext{
Context: ctx, Context: ctx,
@@ -615,172 +606,135 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
// Use potentially modified data from hook context // Use potentially modified data from hook context
data = hookCtx.Data data = hookCtx.Data
// Handle batch creation // Normalize data to slice for unified processing
dataValue := reflect.ValueOf(data) dataSlice := h.normalizeToSlice(data)
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array { logger.Debug("Processing %d item(s) for creation", len(dataSlice))
logger.Debug("Batch creation detected, count: %d", dataValue.Len())
// Check if any item needs nested processing // Store original data maps for merging later
hasNestedData := false originalDataMaps := make([]map[string]interface{}, 0, len(dataSlice))
for i := 0; i < dataValue.Len(); i++ {
item := dataValue.Index(i).Interface()
if itemMap, ok := item.(map[string]interface{}); ok {
if h.shouldUseNestedProcessor(itemMap, model) {
hasNestedData = true
break
}
}
}
if hasNestedData { // Process all items in a transaction
logger.Info("Using nested CUD processor for batch create with nested data") results := make([]interface{}, 0, len(dataSlice))
results := make([]interface{}, 0, dataValue.Len()) err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { // Create temporary nested processor with transaction
// Temporarily swap the database to use transaction txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
originalDB := h.nestedProcessor
h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
defer func() {
h.nestedProcessor = originalDB
}()
for i := 0; i < dataValue.Len(); i++ { for i, item := range dataSlice {
item := dataValue.Index(i).Interface() itemMap, ok := item.(map[string]interface{})
if itemMap, ok := item.(map[string]interface{}); ok { if !ok {
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", itemMap, model, make(map[string]interface{}), tableName) // Convert to map if needed
if err != nil {
return fmt.Errorf("failed to process item: %w", err)
}
results = append(results, result.Data)
}
}
return nil
})
if err != nil {
logger.Error("Error creating records with nested data: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err)
return
}
// Execute AfterCreate hooks
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
logger.Info("Successfully created %d records with nested data", len(results))
h.sendResponseWithOptions(w, results, nil, &options)
return
}
// Standard batch insert without nested relations
// Use transaction for batch insert
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for i := 0; i < dataValue.Len(); i++ {
item := dataValue.Index(i).Interface()
// Convert item to model type - create a pointer to the model
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
jsonData, err := json.Marshal(item) jsonData, err := json.Marshal(item)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal item: %w", err) return fmt.Errorf("failed to marshal item %d: %w", i, err)
} }
if err := json.Unmarshal(jsonData, modelValue); err != nil { itemMap = make(map[string]interface{})
return fmt.Errorf("failed to unmarshal item: %w", err) if err := json.Unmarshal(jsonData, &itemMap); err != nil {
} return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
query := tx.NewInsert().Model(modelValue).Table(tableName)
// Execute BeforeScan hooks - pass query chain so hooks can modify it
batchHookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
Data: modelValue,
Writer: w,
Query: query,
}
if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil {
return fmt.Errorf("BeforeScan hook failed: %w", err)
}
// Use potentially modified query from hook context
if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok {
query = modifiedQuery
}
if _, err := query.Exec(ctx); err != nil {
return fmt.Errorf("failed to insert record: %w", err)
} }
} }
return nil
})
if err != nil { // Store a copy of the original data map for merging later
logger.Error("Error creating records: %v", err) originalMap := make(map[string]interface{})
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err) for k, v := range itemMap {
return originalMap[k] = v
}
originalDataMaps = append(originalDataMaps, originalMap)
// Extract nested relations if present (but don't process them yet)
var nestedRelations map[string]interface{}
if h.shouldUseNestedProcessor(itemMap, model) {
logger.Debug("Extracting nested relations for item %d", i)
cleanedData, relations, err := h.extractNestedRelations(itemMap, model)
if err != nil {
return fmt.Errorf("failed to extract nested relations for item %d: %w", i, err)
}
itemMap = cleanedData
nestedRelations = relations
}
// Convert item to model type - create a pointer to the model
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
jsonData, err := json.Marshal(itemMap)
if err != nil {
return fmt.Errorf("failed to marshal item %d: %w", i, err)
}
if err := json.Unmarshal(jsonData, modelValue); err != nil {
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
}
// Create insert query
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
// Execute BeforeScan hooks - pass query chain so hooks can modify it
itemHookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
Data: modelValue,
Writer: w,
Query: query,
}
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
}
// Use potentially modified query from hook context
if modifiedQuery, ok := itemHookCtx.Query.(common.InsertQuery); ok {
query = modifiedQuery
}
// Execute insert and get the ID
if _, err := query.Exec(ctx); err != nil {
return fmt.Errorf("failed to insert item %d: %w", i, err)
}
// Get the inserted ID
insertedID := reflection.GetPrimaryKeyValue(modelValue)
// Now process nested relations with the parent ID
if len(nestedRelations) > 0 {
logger.Debug("Processing nested relations for item %d with parent ID: %v", i, insertedID)
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "insert", nestedRelations, model, insertedID); err != nil {
return fmt.Errorf("failed to process nested relations for item %d: %w", i, err)
}
}
results = append(results, modelValue)
} }
return nil
})
// Execute AfterCreate hooks for batch creation
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
return
}
// Single record creation - create a pointer to the model
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error creating records: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
return
}
if err := json.Unmarshal(jsonData, modelValue); err != nil {
logger.Error("Error unmarshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
return return
} }
query := h.db.NewInsert().Model(modelValue).Table(tableName) // Merge created records with original request data
// This preserves extra keys from the request
// Execute BeforeScan hooks - pass query chain so hooks can modify it mergedResults := make([]interface{}, 0, len(results))
hookCtx.Data = modelValue for i, result := range results {
hookCtx.Query = query if i < len(originalDataMaps) {
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { merged := h.mergeRecordWithRequest(result, originalDataMaps[i])
logger.Error("BeforeScan hook failed: %v", err) mergedResults = append(mergedResults, merged)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) } else {
return mergedResults = append(mergedResults, result)
}
} }
// Use potentially modified query from hook context // Execute AfterCreate hooks
if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok { var responseData interface{}
query = modifiedQuery if len(mergedResults) == 1 {
responseData = mergedResults[0]
hookCtx.Result = mergedResults[0]
} else {
responseData = mergedResults
hookCtx.Result = map[string]interface{}{"created": len(mergedResults), "data": mergedResults}
} }
if _, err := query.Exec(ctx); err != nil {
logger.Error("Error creating record: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
return
}
// Execute AfterCreate hooks for single record creation
hookCtx.Result = modelValue
hookCtx.Error = nil hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
@@ -789,7 +743,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return return
} }
h.sendResponseWithOptions(w, modelValue, nil, &options) logger.Info("Successfully created %d record(s)", len(mergedResults))
h.sendResponseWithOptions(w, responseData, nil, &options)
} }
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) { func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
@@ -807,46 +762,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
logger.Info("Updating record in %s.%s", schema, entity) logger.Info("Updating record in %s.%s", schema, entity)
// Convert data to map first for nested processor check
dataMap, ok := data.(map[string]interface{})
if !ok {
jsonData, err := json.Marshal(data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
return
}
if err := json.Unmarshal(jsonData, &dataMap); err != nil {
logger.Error("Error unmarshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
return
}
}
// Check if we should use nested processing
if h.shouldUseNestedProcessor(dataMap, model) {
logger.Info("Using nested CUD processor for update operation")
// Ensure ID is in the data map
var targetID interface{}
if id != "" {
targetID = id
} else if idPtr != nil {
targetID = *idPtr
}
if targetID != nil {
dataMap["id"] = targetID
}
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", dataMap, model, make(map[string]interface{}), tableName)
if err != nil {
logger.Error("Error in nested update: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record with nested data", err)
return
}
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
h.sendResponseWithOptions(w, result.Data, nil, &options)
return
}
// Execute BeforeUpdate hooks // Execute BeforeUpdate hooks
hookCtx := &HookContext{ hookCtx := &HookContext{
Context: ctx, Context: ctx,
@@ -870,8 +785,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
// Use potentially modified data from hook context // Use potentially modified data from hook context
data = hookCtx.Data data = hookCtx.Data
// Convert data to map (again if modified by hooks) // Convert data to map
dataMap, ok = data.(map[string]interface{}) dataMap, ok := data.(map[string]interface{})
if !ok { if !ok {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
@@ -886,53 +801,108 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
} }
} }
query := h.db.NewUpdate().Table(tableName).SetMap(dataMap) // Determine target ID
pkName := reflection.GetPrimaryKeyName(model) var targetID interface{}
// Apply ID filter if id != "" {
switch { targetID = id
case id != "": } else if idPtr != nil {
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) targetID = *idPtr
case idPtr != nil: } else {
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *idPtr)
default:
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil) h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil)
return return
} }
// Execute BeforeScan hooks - pass query chain so hooks can modify it // Get the primary key name for the model
hookCtx.Query = query pkName := reflection.GetPrimaryKeyName(model)
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
logger.Error("BeforeScan hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified query from hook context // Variable to store the updated record
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok { var updatedRecord interface{}
query = modifiedQuery
} // Process nested relations if present
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Create temporary nested processor with transaction
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
// Extract nested relations if present (but don't process them yet)
var nestedRelations map[string]interface{}
if h.shouldUseNestedProcessor(dataMap, model) {
logger.Debug("Extracting nested relations for update")
cleanedData, relations, err := h.extractNestedRelations(dataMap, model)
if err != nil {
return fmt.Errorf("failed to extract nested relations: %w", err)
}
dataMap = cleanedData
nestedRelations = relations
}
// Ensure ID is in the data map for the update
dataMap[pkName] = targetID
// Create update query
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
// Execute BeforeScan hooks - pass query chain so hooks can modify it
hookCtx.Query = query
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
return fmt.Errorf("BeforeScan hook failed: %w", err)
}
// Use potentially modified query from hook context
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
query = modifiedQuery
}
// Execute update
result, err := query.Exec(ctx)
if err != nil {
return fmt.Errorf("failed to update record: %w", err)
}
// Now process nested relations with the parent ID
if len(nestedRelations) > 0 {
logger.Debug("Processing nested relations for update with parent ID: %v", targetID)
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "update", nestedRelations, model, targetID); err != nil {
return fmt.Errorf("failed to process nested relations: %w", err)
}
}
// Fetch the updated record to return the new values
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
selectQuery := tx.NewSelect().Model(modelValue).Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
return fmt.Errorf("failed to fetch updated record: %w", err)
}
updatedRecord = modelValue
// Store result for hooks
hookCtx.Result = updatedRecord
_ = result // Keep result variable for potential future use
return nil
})
result, err := query.Exec(ctx)
if err != nil { if err != nil {
logger.Error("Error updating record: %v", err) logger.Error("Error updating record: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err) h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err)
return return
} }
// Execute AfterUpdate hooks // Merge the updated record with the original request data
responseData := map[string]interface{}{ // This preserves extra keys from the request and updates values from the database
"updated": result.RowsAffected(), mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
}
hookCtx.Result = responseData
hookCtx.Error = nil
// Execute AfterUpdate hooks
hookCtx.Result = mergedData
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
logger.Error("AfterUpdate hook failed: %v", err) logger.Error("AfterUpdate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return return
} }
h.sendResponseWithOptions(w, responseData, nil, &options) logger.Info("Successfully updated record with ID: %v", targetID)
h.sendResponseWithOptions(w, mergedData, nil, &options)
} }
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) { func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
@@ -1006,6 +976,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
// Array of IDs or objects with ID field // Array of IDs or objects with ID field
logger.Info("Batch delete with %d items ([]interface{})", len(v)) logger.Info("Batch delete with %d items ([]interface{})", len(v))
deletedCount := 0 deletedCount := 0
pkName := reflection.GetPrimaryKeyName(model)
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v { for _, item := range v {
var itemID interface{} var itemID interface{}
@@ -1015,7 +986,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
case string: case string:
itemID = v itemID = v
case map[string]interface{}: case map[string]interface{}:
itemID = v["id"] itemID = v[pkName]
default: default:
itemID = item itemID = item
} }
@@ -1072,9 +1043,10 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
// Array of objects with id field // Array of objects with id field
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v)) logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
deletedCount := 0 deletedCount := 0
pkName := reflection.GetPrimaryKeyName(model)
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v { for _, item := range v {
if itemID, ok := item["id"]; ok && itemID != nil { if itemID, ok := item[pkName]; ok && itemID != nil {
itemIDStr := fmt.Sprintf("%v", itemID) itemIDStr := fmt.Sprintf("%v", itemID)
// Execute hooks for each item // Execute hooks for each item
@@ -1122,7 +1094,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
case map[string]interface{}: case map[string]interface{}:
// Single object with id field // Single object with id field
if itemID, ok := v["id"]; ok && itemID != nil { pkName := reflection.GetPrimaryKeyName(model)
if itemID, ok := v[pkName]; ok && itemID != nil {
id = fmt.Sprintf("%v", itemID) id = fmt.Sprintf("%v", itemID)
} }
} }
@@ -1192,6 +1165,229 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
h.sendResponse(w, responseData, nil) h.sendResponse(w, responseData, nil)
} }
// mergeRecordWithRequest merges a database record with the original request data
// This preserves extra keys from the request that aren't in the database model
// and updates values from the database (e.g., from SQL triggers or defaults)
func (h *Handler) mergeRecordWithRequest(dbRecord interface{}, requestData map[string]interface{}) map[string]interface{} {
// Convert the database record to a map
dbMap := make(map[string]interface{})
// Marshal and unmarshal to convert struct to map
jsonData, err := json.Marshal(dbRecord)
if err != nil {
logger.Warn("Failed to marshal database record for merging: %v", err)
return requestData
}
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
logger.Warn("Failed to unmarshal database record for merging: %v", err)
return requestData
}
// Start with the request data (preserves extra keys)
result := make(map[string]interface{})
for k, v := range requestData {
result[k] = v
}
// Update with values from database (overwrites with DB values, including trigger changes)
for k, v := range dbMap {
result[k] = v
}
return result
}
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
if data == nil {
return []interface{}{}
}
dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
result := make([]interface{}, dataValue.Len())
for i := 0; i < dataValue.Len(); i++ {
result[i] = dataValue.Index(i).Interface()
}
return result
}
// Single item - return as 1-item slice
return []interface{}{data}
}
// extractNestedRelations extracts nested relations from data, returning cleaned data and relations
// This does NOT process the relations, just separates them for later processing
func (h *Handler) extractNestedRelations(
data map[string]interface{},
model interface{},
) (map[string]interface{}, map[string]interface{}, error) {
// Get model type for reflection
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return data, nil, fmt.Errorf("model must be a struct type, got %v", modelType)
}
// Separate relation fields from regular fields
cleanedData := make(map[string]interface{})
relations := make(map[string]interface{})
for key, value := range data {
// Skip _request field
if key == "_request" {
continue
}
// Check if this field is a relation
relInfo := h.GetRelationshipInfo(modelType, key)
if relInfo != nil {
logger.Debug("Found nested relation field: %s (type: %s)", key, relInfo.RelationType)
relations[key] = value
} else {
cleanedData[key] = value
}
}
return cleanedData, relations, nil
}
// processChildRelationsWithParentID processes nested relations with a parent ID
func (h *Handler) processChildRelationsWithParentID(
ctx context.Context,
processor *common.NestedCUDProcessor,
operation string,
relations map[string]interface{},
parentModel interface{},
parentID interface{},
) error {
// Get model type for reflection
modelType := reflect.TypeOf(parentModel)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return fmt.Errorf("model must be a struct type, got %v", modelType)
}
// Process each relation
for relationName, relationValue := range relations {
if relationValue == nil {
continue
}
// Get relationship info
relInfo := h.GetRelationshipInfo(modelType, relationName)
if relInfo == nil {
logger.Warn("No relationship info found for %s, skipping", relationName)
continue
}
// Process this relation with parent ID
if err := h.processChildRelationsForField(ctx, processor, operation, relationName, relationValue, relInfo, modelType, parentID); err != nil {
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
}
}
return nil
}
// processChildRelationsForField processes a single nested relation field
func (h *Handler) processChildRelationsForField(
ctx context.Context,
processor *common.NestedCUDProcessor,
operation string,
relationName string,
relationValue interface{},
relInfo *common.RelationshipInfo,
parentModelType reflect.Type,
parentID interface{},
) error {
if relationValue == nil {
return nil
}
// Get the related model
field, found := parentModelType.FieldByName(relInfo.FieldName)
if !found {
return fmt.Errorf("field %s not found in model", relInfo.FieldName)
}
// Get the model type for the relation
relatedModelType := field.Type
if relatedModelType.Kind() == reflect.Slice {
relatedModelType = relatedModelType.Elem()
}
if relatedModelType.Kind() == reflect.Ptr {
relatedModelType = relatedModelType.Elem()
}
// Create an instance of the related model
relatedModel := reflect.New(relatedModelType).Elem().Interface()
// Get table name for related model
relatedTableName := h.getTableNameForRelatedModel(relatedModel, relInfo.JSONName)
// Prepare parent IDs for foreign key injection
parentIDs := make(map[string]interface{})
if relInfo.ForeignKey != "" && parentID != nil {
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
parentIDs[baseName] = parentID
}
// Process based on relation type and data structure
switch v := relationValue.(type) {
case map[string]interface{}:
// Single related object
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process single relation: %w", err)
}
case []interface{}:
// Multiple related objects
for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok {
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation item %d: %w", i, err)
}
}
}
case []map[string]interface{}:
// Multiple related objects (typed slice)
for i, itemMap := range v {
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation item %d: %w", i, err)
}
}
default:
return fmt.Errorf("unsupported relation data type: %T", relationValue)
}
return nil
}
// getTableNameForRelatedModel gets the table name for a related model
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
return tableName
}
}
return defaultName
}
// qualifyColumnName ensures column name is fully qualified with table name if not already // qualifyColumnName ensures column name is fully qualified with table name if not already
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string { func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
// Check if column already has a table/schema prefix (contains a dot) // Check if column already has a table/schema prefix (contains a dot)
@@ -1538,22 +1734,22 @@ func (h *Handler) cleanJSON(data interface{}) interface{} {
} }
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) { func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
var details string var errorMsg string
if err != nil { if err != nil {
details = err.Error() errorMsg = err.Error()
} else if message != "" {
errorMsg = message
} else {
errorMsg = code
} }
response := common.Response{ response := map[string]interface{}{
Success: false, "_error": errorMsg,
Error: &common.APIError{ "_retval": 1,
Code: code,
Message: message,
Details: details,
},
} }
w.WriteHeader(statusCode) w.WriteHeader(statusCode)
if err := w.WriteJSON(response); err != nil { if jsonErr := w.WriteJSON(response); jsonErr != nil {
logger.Error("Failed to write JSON error response: %v", err) logger.Error("Failed to write JSON error response: %v", jsonErr)
} }
} }
@@ -1784,7 +1980,8 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
} }
// shouldUseNestedProcessor determines if we should use nested CUD processing // shouldUseNestedProcessor determines if we should use nested CUD processing
// It checks if the data contains nested relations or a _request field // It recursively checks if the data contains deeply nested relations or _request fields
// Simple one-level relations without further nesting don't require the nested processor
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool { func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
return common.ShouldUseNestedProcessor(data, model, h) return common.ShouldUseNestedProcessor(data, model, h)
} }
@@ -1846,12 +2043,40 @@ func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName strin
// Determine if it's belongsTo or hasMany/hasOne // Determine if it's belongsTo or hasMany/hasOne
if field.Type.Kind() == reflect.Slice { if field.Type.Kind() == reflect.Slice {
info.relationType = "hasMany" info.relationType = "hasMany"
// Get the element type for slice
elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
info.relatedModel = reflect.New(elemType).Elem().Interface()
}
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
info.relationType = "belongsTo" info.relationType = "belongsTo"
elemType := field.Type
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
info.relatedModel = reflect.New(elemType).Elem().Interface()
}
} }
} else if strings.Contains(gormTag, "many2many") { } else if strings.Contains(gormTag, "many2many") {
info.relationType = "many2many" info.relationType = "many2many"
info.joinTable = h.extractTagValue(gormTag, "many2many") info.joinTable = h.extractTagValue(gormTag, "many2many")
// Get the element type for many2many (always slice)
if field.Type.Kind() == reflect.Slice {
elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
info.relatedModel = reflect.New(elemType).Elem().Interface()
}
}
} else {
// Field has no GORM relationship tags, so it's not a relation
return nil
} }
return info return info

View File

@@ -0,0 +1,423 @@
package restheadspec
import (
"fmt"
"reflect"
"testing"
)
// Test models for nested CRUD operations
type TestUser struct {
ID int64 `json:"id" bun:"id,pk,autoincrement"`
Name string `json:"name"`
Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"`
}
type TestPost struct {
ID int64 `json:"id" bun:"id,pk,autoincrement"`
UserID int64 `json:"user_id"`
Title string `json:"title"`
Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"`
}
type TestComment struct {
ID int64 `json:"id" bun:"id,pk,autoincrement"`
PostID int64 `json:"post_id"`
Content string `json:"content"`
}
func (TestUser) TableName() string { return "users" }
func (TestPost) TableName() string { return "posts" }
func (TestComment) TableName() string { return "comments" }
// Test extractNestedRelations function
func TestExtractNestedRelations(t *testing.T) {
// Create handler
registry := &mockRegistry{
models: map[string]interface{}{
"users": TestUser{},
"posts": TestPost{},
"comments": TestComment{},
},
}
handler := NewHandler(nil, registry)
tests := []struct {
name string
data map[string]interface{}
model interface{}
expectedCleanCount int
expectedRelCount int
}{
{
name: "User with posts",
data: map[string]interface{}{
"name": "John Doe",
"posts": []map[string]interface{}{
{"title": "Post 1"},
},
},
model: TestUser{},
expectedCleanCount: 1, // name
expectedRelCount: 1, // posts
},
{
name: "Post with comments",
data: map[string]interface{}{
"title": "Test Post",
"comments": []map[string]interface{}{
{"content": "Comment 1"},
{"content": "Comment 2"},
},
},
model: TestPost{},
expectedCleanCount: 1, // title
expectedRelCount: 1, // comments
},
{
name: "User with nested posts and comments",
data: map[string]interface{}{
"name": "Jane Doe",
"posts": []map[string]interface{}{
{
"title": "Post 1",
"comments": []map[string]interface{}{
{"content": "Comment 1"},
},
},
},
},
model: TestUser{},
expectedCleanCount: 1, // name
expectedRelCount: 1, // posts (which contains nested comments)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model)
if err != nil {
t.Errorf("extractNestedRelations() error = %v", err)
return
}
if len(cleanedData) != tt.expectedCleanCount {
t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData)
}
if len(relations) != tt.expectedRelCount {
t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations)
}
t.Logf("Cleaned data: %+v", cleanedData)
t.Logf("Relations: %+v", relations)
})
}
}
// Test shouldUseNestedProcessor function
func TestShouldUseNestedProcessor(t *testing.T) {
registry := &mockRegistry{
models: map[string]interface{}{
"users": TestUser{},
"posts": TestPost{},
},
}
handler := NewHandler(nil, registry)
tests := []struct {
name string
data map[string]interface{}
model interface{}
expected bool
}{
{
name: "Data with simple nested posts (no further nesting)",
data: map[string]interface{}{
"name": "John",
"posts": []map[string]interface{}{
{"title": "Post 1"},
},
},
model: TestUser{},
expected: false, // Simple one-level nesting doesn't require nested processor
},
{
name: "Data with deeply nested relations",
data: map[string]interface{}{
"name": "John",
"posts": []map[string]interface{}{
{
"title": "Post 1",
"comments": []map[string]interface{}{
{"content": "Comment 1"},
},
},
},
},
model: TestUser{},
expected: true, // Multi-level nesting requires nested processor
},
{
name: "Data without nested relations",
data: map[string]interface{}{
"name": "John",
},
model: TestUser{},
expected: false,
},
{
name: "Data with _request field",
data: map[string]interface{}{
"_request": "insert",
"name": "John",
},
model: TestUser{},
expected: true,
},
{
name: "Nested data with _request field",
data: map[string]interface{}{
"name": "John",
"posts": []map[string]interface{}{
{
"_request": "insert",
"title": "Post 1",
},
},
},
model: TestUser{},
expected: true, // _request at nested level requires nested processor
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.shouldUseNestedProcessor(tt.data, tt.model)
if result != tt.expected {
t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected)
}
})
}
}
// Test normalizeToSlice function
func TestNormalizeToSlice(t *testing.T) {
registry := &mockRegistry{}
handler := NewHandler(nil, registry)
tests := []struct {
name string
input interface{}
expected int // expected slice length
}{
{
name: "Single object",
input: map[string]interface{}{"name": "John"},
expected: 1,
},
{
name: "Slice of objects",
input: []map[string]interface{}{
{"name": "John"},
{"name": "Jane"},
},
expected: 2,
},
{
name: "Array of interfaces",
input: []interface{}{
map[string]interface{}{"name": "John"},
map[string]interface{}{"name": "Jane"},
map[string]interface{}{"name": "Bob"},
},
expected: 3,
},
{
name: "Nil input",
input: nil,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.normalizeToSlice(tt.input)
if len(result) != tt.expected {
t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected)
}
})
}
}
// Test GetRelationshipInfo function
func TestGetRelationshipInfo(t *testing.T) {
registry := &mockRegistry{}
handler := NewHandler(nil, registry)
tests := []struct {
name string
modelType reflect.Type
relationName string
expectNil bool
}{
{
name: "User posts relation",
modelType: reflect.TypeOf(TestUser{}),
relationName: "posts",
expectNil: false,
},
{
name: "Post comments relation",
modelType: reflect.TypeOf(TestPost{}),
relationName: "comments",
expectNil: false,
},
{
name: "Non-existent relation",
modelType: reflect.TypeOf(TestUser{}),
relationName: "nonexistent",
expectNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.GetRelationshipInfo(tt.modelType, tt.relationName)
if tt.expectNil && result != nil {
t.Errorf("Expected nil, got %+v", result)
}
if !tt.expectNil && result == nil {
t.Errorf("Expected non-nil relationship info")
}
if result != nil {
t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s",
result.FieldName, result.JSONName, result.RelationType, result.ForeignKey)
}
})
}
}
// Mock registry for testing
type mockRegistry struct {
models map[string]interface{}
}
func (m *mockRegistry) Register(name string, model interface{}) {
m.RegisterModel(name, model)
}
func (m *mockRegistry) RegisterModel(name string, model interface{}) error {
if m.models == nil {
m.models = make(map[string]interface{})
}
m.models[name] = model
return nil
}
func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
if model, ok := m.models[entity]; ok {
return model, nil
}
return nil, fmt.Errorf("model not found: %s", entity)
}
func (m *mockRegistry) GetModelByName(name string) (interface{}, error) {
if model, ok := m.models[name]; ok {
return model, nil
}
return nil, fmt.Errorf("model not found: %s", name)
}
func (m *mockRegistry) GetModel(name string) (interface{}, error) {
return m.GetModelByName(name)
}
func (m *mockRegistry) HasModel(schema, entity string) bool {
_, ok := m.models[entity]
return ok
}
func (m *mockRegistry) ListModels() []string {
models := make([]string, 0, len(m.models))
for name := range m.models {
models = append(models, name)
}
return models
}
func (m *mockRegistry) GetAllModels() map[string]interface{} {
return m.models
}
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
func TestMultiLevelRelationExtraction(t *testing.T) {
registry := &mockRegistry{
models: map[string]interface{}{
"users": TestUser{},
"posts": TestPost{},
"comments": TestComment{},
},
}
handler := NewHandler(nil, registry)
// Test data with 3 levels: User -> Posts -> Comments
testData := map[string]interface{}{
"name": "John Doe",
"posts": []map[string]interface{}{
{
"title": "First Post",
"comments": []map[string]interface{}{
{"content": "Great post!"},
{"content": "Thanks for sharing!"},
},
},
{
"title": "Second Post",
"comments": []map[string]interface{}{
{"content": "Interesting read"},
},
},
},
}
// Extract relations from user
cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{})
if err != nil {
t.Fatalf("Failed to extract relations: %v", err)
}
// Verify user data is cleaned
if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" {
t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData)
}
// Verify posts relation was extracted
if len(relations) != 1 {
t.Errorf("Expected 1 relation (posts), got %d", len(relations))
}
posts, ok := relations["posts"]
if !ok {
t.Fatal("Expected posts relation to be extracted")
}
// Verify posts is a slice with 2 items
postsSlice, ok := posts.([]map[string]interface{})
if !ok {
t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts)
}
if len(postsSlice) != 2 {
t.Errorf("Expected 2 posts, got %d", len(postsSlice))
}
// Verify first post has comments
if _, hasComments := postsSlice[0]["comments"]; !hasComments {
t.Error("Expected first post to have comments")
}
t.Logf("Successfully extracted multi-level nested relations")
t.Logf("Cleaned data: %+v", cleanedData)
t.Logf("Relations: %d posts with nested comments", len(postsSlice))
}

View File

@@ -106,7 +106,7 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
reqAdapter := router.NewHTTPRequest(r) reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w) respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars) handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("GET", "PUT", "PATCH", "DELETE") }).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
// GET for metadata (using HandleGet) // GET for metadata (using HandleGet)
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) { muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
@@ -189,6 +189,18 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
return nil return nil
}) })
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{ params := map[string]string{
"schema": req.Param("schema"), "schema": req.Param("schema"),

View File

@@ -10,7 +10,7 @@ import (
type TestModel struct { type TestModel struct {
ID int64 `json:"id" bun:"id,pk"` ID int64 `json:"id" bun:"id,pk"`
Name string `json:"name" bun:"name"` Name string `json:"name" bun:"name"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"` RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
} }
func TestSetRowNumbersOnRecords(t *testing.T) { func TestSetRowNumbersOnRecords(t *testing.T) {