correctly handle structs with embedded fields

This commit is contained in:
Hein 2025-11-20 09:28:37 +02:00
parent 66b6a0d835
commit 35089f511f
4 changed files with 429 additions and 41 deletions

View File

@ -18,6 +18,7 @@ type ModelFieldDetail struct {
} }
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model // GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
// This function recursively processes embedded structs to include their fields
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
if record.Kind() != reflect.Struct { if record.Kind() != reflect.Struct {
return lst return lst
} }
collectFieldDetails(record, &lst)
return lst
}
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
modeltype := record.Type() modeltype := record.Type()
for i := 0; i < modeltype.NumField(); i++ { for i := 0; i < modeltype.NumField(); i++ {
fieldtype := modeltype.Field(i) fieldtype := modeltype.Field(i)
fieldValue := record.Field(i)
// Check if this is an embedded struct
if fieldtype.Anonymous {
// Unwrap pointer type if necessary
embeddedValue := fieldValue
if fieldValue.Kind() == reflect.Pointer {
if fieldValue.IsNil() {
// Skip nil embedded pointers
continue
}
embeddedValue = fieldValue.Elem()
}
// Recursively process embedded struct
if embeddedValue.Kind() == reflect.Struct {
collectFieldDetails(embeddedValue, lst)
continue
}
}
gormdetail := fieldtype.Tag.Get("gorm") gormdetail := fieldtype.Tag.Get("gorm")
gormdetail = strings.Trim(gormdetail, " ") gormdetail = strings.Trim(gormdetail, " ")
fielddetail := ModelFieldDetail{} fielddetail := ModelFieldDetail{}
fielddetail.FieldValue = record.Field(i) fielddetail.FieldValue = fieldValue
fielddetail.Name = fieldtype.Name fielddetail.Name = fieldtype.Name
fielddetail.DataType = fieldtype.Type.Name() fielddetail.DataType = fieldtype.Type.Name()
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:") fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
} }
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;" // ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
lst = append(lst, fielddetail) *lst = append(*lst, fielddetail)
} }
return lst
} }
func fnFindKeyVal(src, key string) string { func fnFindKeyVal(src, key string) string {

View File

@ -47,7 +47,7 @@ func GetPrimaryKeyName(model any) string {
// GetPrimaryKeyValue extracts the primary key value from a model instance // GetPrimaryKeyValue extracts the primary key value from a model instance
// Returns the value of the primary key field // Returns the value of the primary key field
func GetPrimaryKeyValue(model any) interface{} { func GetPrimaryKeyValue(model any) any {
if model == nil || reflect.TypeOf(model) == nil { if model == nil || reflect.TypeOf(model) == nil {
return nil return nil
} }
@ -61,38 +61,51 @@ func GetPrimaryKeyValue(model any) interface{} {
return nil return nil
} }
typ := val.Type()
// Try Bun tag first // Try Bun tag first
for i := 0; i < typ.NumField(); i++ { if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
field := typ.Field(i) return pkValue
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") {
fieldValue := val.Field(i)
if fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
} }
// Fall back to GORM tag // Fall back to GORM tag
for i := 0; i < typ.NumField(); i++ { if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
field := typ.Field(i) return pkValue
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") {
fieldValue := val.Field(i)
if fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
} }
// Last resort: look for field named "ID" or "Id" // Last resort: look for field named "ID" or "Id"
if pkValue := findFieldByName(val, "id"); pkValue != nil {
return pkValue
}
return nil
}
// findPrimaryKeyValue recursively searches for a primary key field in the struct
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
if strings.ToLower(field.Name) == "id" { fieldValue := val.Field(i)
fieldValue := val.Field(i)
if fieldValue.CanInterface() { // Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
return pkValue
}
continue
}
// Check for primary key tag
switch ormType {
case "bun":
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
case "gorm":
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
return fieldValue.Interface() return fieldValue.Interface()
} }
} }
@ -101,8 +114,35 @@ func GetPrimaryKeyValue(model any) interface{} {
return nil return nil
} }
// findFieldByName recursively searches for a field by name in the struct
func findFieldByName(val reflect.Value, name string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if result := findFieldByName(fieldValue, name); result != nil {
return result
}
continue
}
// Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
return nil
}
// GetModelColumns extracts all column names from a model using reflection // GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names // It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
// This function recursively processes embedded structs to include their fields
func GetModelColumns(model any) []string { func GetModelColumns(model any) []string {
var columns []string var columns []string
@ -118,18 +158,38 @@ func GetModelColumns(model any) []string {
return columns return columns
} }
for i := 0; i < modelType.NumField(); i++ { collectColumnsFromType(modelType, &columns)
field := modelType.Field(i)
return columns
}
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively process embedded struct
if fieldType.Kind() == reflect.Struct {
collectColumnsFromType(fieldType, columns)
continue
}
}
// Get column name using the same logic as primary key extraction // Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field) columnName := getColumnNameFromField(field)
if columnName != "" { if columnName != "" {
columns = append(columns, columnName) *columns = append(*columns, columnName)
} }
} }
return columns
} }
// getColumnNameFromField extracts the column name from a struct field // getColumnNameFromField extracts the column name from a struct field
@ -166,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
} }
// getPrimaryKeyFromReflection uses reflection to find the primary key field // getPrimaryKeyFromReflection uses reflection to find the primary key field
// This function recursively searches embedded structs
func getPrimaryKeyFromReflection(model any, ormType string) string { func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer { if val.Kind() == reflect.Pointer {
@ -177,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
} }
typ := val.Type() typ := val.Type()
return findPrimaryKeyNameFromType(typ, ormType)
}
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
return pkName
}
}
continue
}
switch ormType { switch ormType {
case "gorm": case "gorm":
// Check for gorm tag with primaryKey // Check for gorm tag with primaryKey
@ -231,6 +314,9 @@ func ExtractColumnFromGormTag(tag string) string {
// Example: ",pk" -> "" (will fall back to json tag) // Example: ",pk" -> "" (will fall back to json tag)
func ExtractColumnFromBunTag(tag string) string { func ExtractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",") parts := strings.Split(tag, ",")
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
return ""
}
if len(parts) > 0 && parts[0] != "" { if len(parts) > 0 && parts[0] != "" {
return parts[0] return parts[0]
} }
@ -240,6 +326,7 @@ func ExtractColumnFromBunTag(tag string) string {
// IsColumnWritable checks if a column can be written to in the database // IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag // For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag // For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
// This function recursively searches embedded structs
func IsColumnWritable(model any, columnName string) bool { func IsColumnWritable(model any, columnName string) bool {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
@ -253,8 +340,37 @@ func IsColumnWritable(model any, columnName string) bool {
return false return false
} }
for i := 0; i < modelType.NumField(); i++ { found, writable := isColumnWritableInType(modelType, columnName)
field := modelType.Field(i) if found {
return writable
}
// Column not found in model, allow it (might be a dynamic column)
return true
}
// isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if found, writable := isColumnWritableInType(fieldType, columnName); found {
return true, writable
}
}
continue
}
// Check if this field matches the column name // Check if this field matches the column name
fieldColumnName := getColumnNameFromField(field) fieldColumnName := getColumnNameFromField(field)
@ -262,11 +378,12 @@ func IsColumnWritable(model any, columnName string) bool {
continue continue
} }
// Found the field, now check if it's writable
// Check bun tag for scanonly // Check bun tag for scanonly
bunTag := field.Tag.Get("bun") bunTag := field.Tag.Get("bun")
if bunTag != "" { if bunTag != "" {
if isBunFieldScanOnly(bunTag) { if isBunFieldScanOnly(bunTag) {
return false return true, false
} }
} }
@ -274,16 +391,16 @@ func IsColumnWritable(model any, columnName string) bool {
gormTag := field.Tag.Get("gorm") gormTag := field.Tag.Get("gorm")
if gormTag != "" { if gormTag != "" {
if isGormFieldReadOnly(gormTag) { if isGormFieldReadOnly(gormTag) {
return false return true, false
} }
} }
// Column is writable // Column is writable
return true return true, true
} }
// Column not found in model, allow it (might be a dynamic column) // Column not found
return true return false, false
} }
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only // isBunFieldScanOnly checks if a bun tag indicates the field is scan-only

View File

@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
}) })
} }
} }
// Test models with embedded structs
type BaseModel struct {
ID int `bun:"rid_base,pk" json:"id"`
CreatedAt string `bun:"created_at" json:"created_at"`
}
type AdhocBuffer struct {
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
type ModelWithEmbedded struct {
BaseModel
Name string `bun:"name" json:"name"`
Description string `bun:"description" json:"description"`
AdhocBuffer
}
type GormBaseModel struct {
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
CreatedAt string `gorm:"column:created_at" json:"created_at"`
}
type GormAdhocBuffer struct {
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
type GormModelWithEmbedded struct {
GormBaseModel
Name string `gorm:"column:name" json:"name"`
Description string `gorm:"column:description" json:"description"`
GormAdhocBuffer
}
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
expected string
}{
{
name: "Bun model with embedded base",
model: ModelWithEmbedded{},
expected: "rid_base",
},
{
name: "Bun model with embedded base (pointer)",
model: &ModelWithEmbedded{},
expected: "rid_base",
},
{
name: "GORM model with embedded base",
model: GormModelWithEmbedded{},
expected: "rid_base",
},
{
name: "GORM model with embedded base (pointer)",
model: &GormModelWithEmbedded{},
expected: "rid_base",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyName(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
bunModel := ModelWithEmbedded{
BaseModel: BaseModel{
ID: 123,
CreatedAt: "2024-01-01",
},
Name: "Test",
Description: "Test Description",
}
gormModel := GormModelWithEmbedded{
GormBaseModel: GormBaseModel{
ID: 456,
CreatedAt: "2024-01-02",
},
Name: "GORM Test",
Description: "GORM Test Description",
}
tests := []struct {
name string
model any
expected any
}{
{
name: "Bun model with embedded base",
model: bunModel,
expected: 123,
},
{
name: "Bun model with embedded base (pointer)",
model: &bunModel,
expected: 123,
},
{
name: "GORM model with embedded base",
model: gormModel,
expected: 456,
},
{
name: "GORM model with embedded base (pointer)",
model: &gormModel,
expected: 456,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyValue(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetModelColumnsWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with embedded structs",
model: ModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "Bun model with embedded structs (pointer)",
model: &ModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "GORM model with embedded structs",
model: GormModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "GORM model with embedded structs (pointer)",
model: &GormModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
}
}
})
}
}
func TestIsColumnWritableWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
columnName string
expected bool
}{
{
name: "Bun model - writable column in main struct",
model: ModelWithEmbedded{},
columnName: "name",
expected: true,
},
{
name: "Bun model - writable column in embedded base",
model: ModelWithEmbedded{},
columnName: "rid_base",
expected: true,
},
{
name: "Bun model - scanonly column in embedded adhoc buffer",
model: ModelWithEmbedded{},
columnName: "cql1",
expected: false,
},
{
name: "Bun model - scanonly column _rownumber",
model: ModelWithEmbedded{},
columnName: "_rownumber",
expected: false,
},
{
name: "GORM model - writable column in main struct",
model: GormModelWithEmbedded{},
columnName: "name",
expected: true,
},
{
name: "GORM model - writable column in embedded base",
model: GormModelWithEmbedded{},
columnName: "rid_base",
expected: true,
},
{
name: "GORM model - readonly column in embedded adhoc buffer",
model: GormModelWithEmbedded{},
columnName: "cql1",
expected: false,
},
{
name: "GORM model - readonly column _rownumber",
model: GormModelWithEmbedded{},
columnName: "_rownumber",
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsColumnWritable(tt.model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}

View File

@ -10,7 +10,7 @@ import (
type TestModel struct { type TestModel struct {
ID int64 `json:"id" bun:"id,pk"` ID int64 `json:"id" bun:"id,pk"`
Name string `json:"name" bun:"name"` Name string `json:"name" bun:"name"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"` RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
} }
func TestSetRowNumbersOnRecords(t *testing.T) { func TestSetRowNumbersOnRecords(t *testing.T) {