diff --git a/pkg/reflection/generic_model.go b/pkg/reflection/generic_model.go index d5e08f3..cb33107 100644 --- a/pkg/reflection/generic_model.go +++ b/pkg/reflection/generic_model.go @@ -18,6 +18,7 @@ type ModelFieldDetail struct { } // GetModelColumnDetail - Get a list of columns in the SQL declaration of the model +// This function recursively processes embedded structs to include their fields func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { defer func() { if r := recover(); r != nil { @@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { if record.Kind() != reflect.Struct { return lst } + + collectFieldDetails(record, &lst) + + return lst +} + +// collectFieldDetails recursively collects field details from a struct value and its embedded fields +func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) { modeltype := record.Type() for i := 0; i < modeltype.NumField(); i++ { fieldtype := modeltype.Field(i) + fieldValue := record.Field(i) + + // Check if this is an embedded struct + if fieldtype.Anonymous { + // Unwrap pointer type if necessary + embeddedValue := fieldValue + if fieldValue.Kind() == reflect.Pointer { + if fieldValue.IsNil() { + // Skip nil embedded pointers + continue + } + embeddedValue = fieldValue.Elem() + } + + // Recursively process embedded struct + if embeddedValue.Kind() == reflect.Struct { + collectFieldDetails(embeddedValue, lst) + continue + } + } + gormdetail := fieldtype.Tag.Get("gorm") gormdetail = strings.Trim(gormdetail, " ") fielddetail := ModelFieldDetail{} - fielddetail.FieldValue = record.Field(i) + fielddetail.FieldValue = fieldValue fielddetail.Name = fieldtype.Name fielddetail.DataType = fieldtype.Type.Name() fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:") @@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { } // ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;" - lst = append(lst, fielddetail) - + *lst = append(*lst, fielddetail) } - return lst } func fnFindKeyVal(src, key string) string { diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index f28d164..ffea02a 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -47,7 +47,7 @@ func GetPrimaryKeyName(model any) string { // GetPrimaryKeyValue extracts the primary key value from a model instance // Returns the value of the primary key field -func GetPrimaryKeyValue(model any) interface{} { +func GetPrimaryKeyValue(model any) any { if model == nil || reflect.TypeOf(model) == nil { return nil } @@ -61,38 +61,51 @@ func GetPrimaryKeyValue(model any) interface{} { return nil } - typ := val.Type() - // Try Bun tag first - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - bunTag := field.Tag.Get("bun") - if strings.Contains(bunTag, "pk") { - fieldValue := val.Field(i) - if fieldValue.CanInterface() { - return fieldValue.Interface() - } - } + if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil { + return pkValue } // Fall back to GORM tag - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - gormTag := field.Tag.Get("gorm") - if strings.Contains(gormTag, "primaryKey") { - fieldValue := val.Field(i) - if fieldValue.CanInterface() { - return fieldValue.Interface() - } - } + if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil { + return pkValue } // Last resort: look for field named "ID" or "Id" + if pkValue := findFieldByName(val, "id"); pkValue != nil { + return pkValue + } + + return nil +} + +// findPrimaryKeyValue recursively searches for a primary key field in the struct +func findPrimaryKeyValue(val reflect.Value, ormType string) any { + typ := val.Type() + for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) - if strings.ToLower(field.Name) == "id" { - fieldValue := val.Field(i) - if fieldValue.CanInterface() { + fieldValue := val.Field(i) + + // Check if this is an embedded struct + if field.Anonymous && field.Type.Kind() == reflect.Struct { + // Recursively search in embedded struct + if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil { + return pkValue + } + continue + } + + // Check for primary key tag + switch ormType { + case "bun": + bunTag := field.Tag.Get("bun") + if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() { + return fieldValue.Interface() + } + case "gorm": + gormTag := field.Tag.Get("gorm") + if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() { return fieldValue.Interface() } } @@ -101,8 +114,35 @@ func GetPrimaryKeyValue(model any) interface{} { return nil } +// findFieldByName recursively searches for a field by name in the struct +func findFieldByName(val reflect.Value, name string) any { + typ := val.Type() + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + fieldValue := val.Field(i) + + // Check if this is an embedded struct + if field.Anonymous && field.Type.Kind() == reflect.Struct { + // Recursively search in embedded struct + if result := findFieldByName(fieldValue, name); result != nil { + return result + } + continue + } + + // Check if field name matches + if strings.ToLower(field.Name) == name && fieldValue.CanInterface() { + return fieldValue.Interface() + } + } + + return nil +} + // GetModelColumns extracts all column names from a model using reflection // It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names +// This function recursively processes embedded structs to include their fields func GetModelColumns(model any) []string { var columns []string @@ -118,18 +158,38 @@ func GetModelColumns(model any) []string { return columns } - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) + collectColumnsFromType(modelType, &columns) + + return columns +} + +// collectColumnsFromType recursively collects column names from a struct type and its embedded fields +func collectColumnsFromType(typ reflect.Type, columns *[]string) { + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + // Check if this is an embedded struct + if field.Anonymous { + // Unwrap pointer type if necessary + fieldType := field.Type + if fieldType.Kind() == reflect.Pointer { + fieldType = fieldType.Elem() + } + + // Recursively process embedded struct + if fieldType.Kind() == reflect.Struct { + collectColumnsFromType(fieldType, columns) + continue + } + } // Get column name using the same logic as primary key extraction columnName := getColumnNameFromField(field) if columnName != "" { - columns = append(columns, columnName) + *columns = append(*columns, columnName) } } - - return columns } // getColumnNameFromField extracts the column name from a struct field @@ -166,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string { } // getPrimaryKeyFromReflection uses reflection to find the primary key field +// This function recursively searches embedded structs func getPrimaryKeyFromReflection(model any, ormType string) string { val := reflect.ValueOf(model) if val.Kind() == reflect.Pointer { @@ -177,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string { } typ := val.Type() + return findPrimaryKeyNameFromType(typ, ormType) +} + +// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type +func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string { for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) + // Check if this is an embedded struct + if field.Anonymous { + // Unwrap pointer type if necessary + fieldType := field.Type + if fieldType.Kind() == reflect.Pointer { + fieldType = fieldType.Elem() + } + + // Recursively search in embedded struct + if fieldType.Kind() == reflect.Struct { + if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" { + return pkName + } + } + continue + } + switch ormType { case "gorm": // Check for gorm tag with primaryKey @@ -231,6 +314,9 @@ func ExtractColumnFromGormTag(tag string) string { // Example: ",pk" -> "" (will fall back to json tag) func ExtractColumnFromBunTag(tag string) string { parts := strings.Split(tag, ",") + if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") { + return "" + } if len(parts) > 0 && parts[0] != "" { return parts[0] } @@ -240,6 +326,7 @@ func ExtractColumnFromBunTag(tag string) string { // IsColumnWritable checks if a column can be written to in the database // For bun: returns false if the field has "scanonly" tag // For gorm: returns false if the field has "<-:false" or "->" (read-only) tag +// This function recursively searches embedded structs func IsColumnWritable(model any, columnName string) bool { modelType := reflect.TypeOf(model) @@ -253,8 +340,37 @@ func IsColumnWritable(model any, columnName string) bool { return false } - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) + found, writable := isColumnWritableInType(modelType, columnName) + if found { + return writable + } + + // Column not found in model, allow it (might be a dynamic column) + return true +} + +// isColumnWritableInType recursively searches for a column and checks if it's writable +// Returns (found, writable) where found indicates if the column was found +func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) { + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + // Check if this is an embedded struct + if field.Anonymous { + // Unwrap pointer type if necessary + fieldType := field.Type + if fieldType.Kind() == reflect.Pointer { + fieldType = fieldType.Elem() + } + + // Recursively search in embedded struct + if fieldType.Kind() == reflect.Struct { + if found, writable := isColumnWritableInType(fieldType, columnName); found { + return true, writable + } + } + continue + } // Check if this field matches the column name fieldColumnName := getColumnNameFromField(field) @@ -262,11 +378,12 @@ func IsColumnWritable(model any, columnName string) bool { continue } + // Found the field, now check if it's writable // Check bun tag for scanonly bunTag := field.Tag.Get("bun") if bunTag != "" { if isBunFieldScanOnly(bunTag) { - return false + return true, false } } @@ -274,16 +391,16 @@ func IsColumnWritable(model any, columnName string) bool { gormTag := field.Tag.Get("gorm") if gormTag != "" { if isGormFieldReadOnly(gormTag) { - return false + return true, false } } // Column is writable - return true + return true, true } - // Column not found in model, allow it (might be a dynamic column) - return true + // Column not found + return false, false } // isBunFieldScanOnly checks if a bun tag indicates the field is scan-only diff --git a/pkg/reflection/model_utils_test.go b/pkg/reflection/model_utils_test.go index dd2f020..2e581a7 100644 --- a/pkg/reflection/model_utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) { }) } } + +// Test models with embedded structs + +type BaseModel struct { + ID int `bun:"rid_base,pk" json:"id"` + CreatedAt string `bun:"created_at" json:"created_at"` +} + +type AdhocBuffer struct { + CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"` + CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"` + RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"` +} + +type ModelWithEmbedded struct { + BaseModel + Name string `bun:"name" json:"name"` + Description string `bun:"description" json:"description"` + AdhocBuffer +} + +type GormBaseModel struct { + ID int `gorm:"column:rid_base;primaryKey" json:"id"` + CreatedAt string `gorm:"column:created_at" json:"created_at"` +} + +type GormAdhocBuffer struct { + CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"` + CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"` + RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"` +} + +type GormModelWithEmbedded struct { + GormBaseModel + Name string `gorm:"column:name" json:"name"` + Description string `gorm:"column:description" json:"description"` + GormAdhocBuffer +} + +func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) { + tests := []struct { + name string + model any + expected string + }{ + { + name: "Bun model with embedded base", + model: ModelWithEmbedded{}, + expected: "rid_base", + }, + { + name: "Bun model with embedded base (pointer)", + model: &ModelWithEmbedded{}, + expected: "rid_base", + }, + { + name: "GORM model with embedded base", + model: GormModelWithEmbedded{}, + expected: "rid_base", + }, + { + name: "GORM model with embedded base (pointer)", + model: &GormModelWithEmbedded{}, + expected: "rid_base", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPrimaryKeyName(tt.model) + if result != tt.expected { + t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) { + bunModel := ModelWithEmbedded{ + BaseModel: BaseModel{ + ID: 123, + CreatedAt: "2024-01-01", + }, + Name: "Test", + Description: "Test Description", + } + + gormModel := GormModelWithEmbedded{ + GormBaseModel: GormBaseModel{ + ID: 456, + CreatedAt: "2024-01-02", + }, + Name: "GORM Test", + Description: "GORM Test Description", + } + + tests := []struct { + name string + model any + expected any + }{ + { + name: "Bun model with embedded base", + model: bunModel, + expected: 123, + }, + { + name: "Bun model with embedded base (pointer)", + model: &bunModel, + expected: 123, + }, + { + name: "GORM model with embedded base", + model: gormModel, + expected: 456, + }, + { + name: "GORM model with embedded base (pointer)", + model: &gormModel, + expected: 456, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPrimaryKeyValue(tt.model) + if result != tt.expected { + t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestGetModelColumnsWithEmbedded(t *testing.T) { + tests := []struct { + name string + model any + expected []string + }{ + { + name: "Bun model with embedded structs", + model: ModelWithEmbedded{}, + expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"}, + }, + { + name: "Bun model with embedded structs (pointer)", + model: &ModelWithEmbedded{}, + expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"}, + }, + { + name: "GORM model with embedded structs", + model: GormModelWithEmbedded{}, + expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"}, + }, + { + name: "GORM model with embedded structs (pointer)", + model: &GormModelWithEmbedded{}, + expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetModelColumns(tt.model) + if len(result) != len(tt.expected) { + t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result) + return + } + for i, col := range result { + if col != tt.expected[i] { + t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i]) + } + } + }) + } +} + +func TestIsColumnWritableWithEmbedded(t *testing.T) { + tests := []struct { + name string + model any + columnName string + expected bool + }{ + { + name: "Bun model - writable column in main struct", + model: ModelWithEmbedded{}, + columnName: "name", + expected: true, + }, + { + name: "Bun model - writable column in embedded base", + model: ModelWithEmbedded{}, + columnName: "rid_base", + expected: true, + }, + { + name: "Bun model - scanonly column in embedded adhoc buffer", + model: ModelWithEmbedded{}, + columnName: "cql1", + expected: false, + }, + { + name: "Bun model - scanonly column _rownumber", + model: ModelWithEmbedded{}, + columnName: "_rownumber", + expected: false, + }, + { + name: "GORM model - writable column in main struct", + model: GormModelWithEmbedded{}, + columnName: "name", + expected: true, + }, + { + name: "GORM model - writable column in embedded base", + model: GormModelWithEmbedded{}, + columnName: "rid_base", + expected: true, + }, + { + name: "GORM model - readonly column in embedded adhoc buffer", + model: GormModelWithEmbedded{}, + columnName: "cql1", + expected: false, + }, + { + name: "GORM model - readonly column _rownumber", + model: GormModelWithEmbedded{}, + columnName: "_rownumber", + expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-" + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsColumnWritable(tt.model, tt.columnName) + if result != tt.expected { + t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected) + } + }) + } +} diff --git a/pkg/restheadspec/rownumber_test.go b/pkg/restheadspec/rownumber_test.go index 5424eec..c53b995 100644 --- a/pkg/restheadspec/rownumber_test.go +++ b/pkg/restheadspec/rownumber_test.go @@ -10,7 +10,7 @@ import ( type TestModel struct { ID int64 `json:"id" bun:"id,pk"` Name string `json:"name" bun:"name"` - RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"` + RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"` } func TestSetRowNumbersOnRecords(t *testing.T) {