mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
correctly handle structs with embedded fields
This commit is contained in:
parent
66b6a0d835
commit
35089f511f
@ -18,6 +18,7 @@ type ModelFieldDetail struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||||
|
// This function recursively processes embedded structs to include their fields
|
||||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
if record.Kind() != reflect.Struct {
|
if record.Kind() != reflect.Struct {
|
||||||
return lst
|
return lst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
collectFieldDetails(record, &lst)
|
||||||
|
|
||||||
|
return lst
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
|
||||||
|
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
|
||||||
modeltype := record.Type()
|
modeltype := record.Type()
|
||||||
|
|
||||||
for i := 0; i < modeltype.NumField(); i++ {
|
for i := 0; i < modeltype.NumField(); i++ {
|
||||||
fieldtype := modeltype.Field(i)
|
fieldtype := modeltype.Field(i)
|
||||||
|
fieldValue := record.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if fieldtype.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
embeddedValue := fieldValue
|
||||||
|
if fieldValue.Kind() == reflect.Pointer {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
// Skip nil embedded pointers
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
embeddedValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if embeddedValue.Kind() == reflect.Struct {
|
||||||
|
collectFieldDetails(embeddedValue, lst)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
gormdetail := fieldtype.Tag.Get("gorm")
|
gormdetail := fieldtype.Tag.Get("gorm")
|
||||||
gormdetail = strings.Trim(gormdetail, " ")
|
gormdetail = strings.Trim(gormdetail, " ")
|
||||||
fielddetail := ModelFieldDetail{}
|
fielddetail := ModelFieldDetail{}
|
||||||
fielddetail.FieldValue = record.Field(i)
|
fielddetail.FieldValue = fieldValue
|
||||||
fielddetail.Name = fieldtype.Name
|
fielddetail.Name = fieldtype.Name
|
||||||
fielddetail.DataType = fieldtype.Type.Name()
|
fielddetail.DataType = fieldtype.Type.Name()
|
||||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||||
@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
}
|
}
|
||||||
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||||
|
|
||||||
lst = append(lst, fielddetail)
|
*lst = append(*lst, fielddetail)
|
||||||
|
|
||||||
}
|
}
|
||||||
return lst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fnFindKeyVal(src, key string) string {
|
func fnFindKeyVal(src, key string) string {
|
||||||
|
|||||||
@ -47,7 +47,7 @@ func GetPrimaryKeyName(model any) string {
|
|||||||
|
|
||||||
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||||
// Returns the value of the primary key field
|
// Returns the value of the primary key field
|
||||||
func GetPrimaryKeyValue(model any) interface{} {
|
func GetPrimaryKeyValue(model any) any {
|
||||||
if model == nil || reflect.TypeOf(model) == nil {
|
if model == nil || reflect.TypeOf(model) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -61,38 +61,51 @@ func GetPrimaryKeyValue(model any) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
typ := val.Type()
|
|
||||||
|
|
||||||
// Try Bun tag first
|
// Try Bun tag first
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
|
||||||
field := typ.Field(i)
|
return pkValue
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if strings.Contains(bunTag, "pk") {
|
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
return fieldValue.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to GORM tag
|
// Fall back to GORM tag
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
|
||||||
field := typ.Field(i)
|
return pkValue
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
if strings.Contains(gormTag, "primaryKey") {
|
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
return fieldValue.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Last resort: look for field named "ID" or "Id"
|
// Last resort: look for field named "ID" or "Id"
|
||||||
|
if pkValue := findFieldByName(val, "id"); pkValue != nil {
|
||||||
|
return pkValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKeyValue recursively searches for a primary key field in the struct
|
||||||
|
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
if strings.ToLower(field.Name) == "id" {
|
|
||||||
fieldValue := val.Field(i)
|
fieldValue := val.Field(i)
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
|
||||||
|
return pkValue
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for primary key tag
|
||||||
|
switch ormType {
|
||||||
|
case "bun":
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
case "gorm":
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
|
||||||
return fieldValue.Interface()
|
return fieldValue.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -101,8 +114,35 @@ func GetPrimaryKeyValue(model any) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findFieldByName recursively searches for a field by name in the struct
|
||||||
|
func findFieldByName(val reflect.Value, name string) any {
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if result := findFieldByName(fieldValue, name); result != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name matches
|
||||||
|
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelColumns extracts all column names from a model using reflection
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||||
|
// This function recursively processes embedded structs to include their fields
|
||||||
func GetModelColumns(model any) []string {
|
func GetModelColumns(model any) []string {
|
||||||
var columns []string
|
var columns []string
|
||||||
|
|
||||||
@ -118,18 +158,38 @@ func GetModelColumns(model any) []string {
|
|||||||
return columns
|
return columns
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
collectColumnsFromType(modelType, &columns)
|
||||||
field := modelType.Field(i)
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
|
||||||
|
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
collectColumnsFromType(fieldType, columns)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get column name using the same logic as primary key extraction
|
// Get column name using the same logic as primary key extraction
|
||||||
columnName := getColumnNameFromField(field)
|
columnName := getColumnNameFromField(field)
|
||||||
|
|
||||||
if columnName != "" {
|
if columnName != "" {
|
||||||
columns = append(columns, columnName)
|
*columns = append(*columns, columnName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return columns
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getColumnNameFromField extracts the column name from a struct field
|
// getColumnNameFromField extracts the column name from a struct field
|
||||||
@ -166,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||||
|
// This function recursively searches embedded structs
|
||||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||||
val := reflect.ValueOf(model)
|
val := reflect.ValueOf(model)
|
||||||
if val.Kind() == reflect.Pointer {
|
if val.Kind() == reflect.Pointer {
|
||||||
@ -177,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
typ := val.Type()
|
typ := val.Type()
|
||||||
|
return findPrimaryKeyNameFromType(typ, ormType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
|
||||||
|
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
|
||||||
|
return pkName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch ormType {
|
switch ormType {
|
||||||
case "gorm":
|
case "gorm":
|
||||||
// Check for gorm tag with primaryKey
|
// Check for gorm tag with primaryKey
|
||||||
@ -231,6 +314,9 @@ func ExtractColumnFromGormTag(tag string) string {
|
|||||||
// Example: ",pk" -> "" (will fall back to json tag)
|
// Example: ",pk" -> "" (will fall back to json tag)
|
||||||
func ExtractColumnFromBunTag(tag string) string {
|
func ExtractColumnFromBunTag(tag string) string {
|
||||||
parts := strings.Split(tag, ",")
|
parts := strings.Split(tag, ",")
|
||||||
|
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if len(parts) > 0 && parts[0] != "" {
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
return parts[0]
|
return parts[0]
|
||||||
}
|
}
|
||||||
@ -240,6 +326,7 @@ func ExtractColumnFromBunTag(tag string) string {
|
|||||||
// IsColumnWritable checks if a column can be written to in the database
|
// IsColumnWritable checks if a column can be written to in the database
|
||||||
// For bun: returns false if the field has "scanonly" tag
|
// For bun: returns false if the field has "scanonly" tag
|
||||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||||
|
// This function recursively searches embedded structs
|
||||||
func IsColumnWritable(model any, columnName string) bool {
|
func IsColumnWritable(model any, columnName string) bool {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
@ -253,8 +340,37 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
found, writable := isColumnWritableInType(modelType, columnName)
|
||||||
field := modelType.Field(i)
|
if found {
|
||||||
|
return writable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column not found in model, allow it (might be a dynamic column)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||||
|
// Returns (found, writable) where found indicates if the column was found
|
||||||
|
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if found, writable := isColumnWritableInType(fieldType, columnName); found {
|
||||||
|
return true, writable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this field matches the column name
|
// Check if this field matches the column name
|
||||||
fieldColumnName := getColumnNameFromField(field)
|
fieldColumnName := getColumnNameFromField(field)
|
||||||
@ -262,11 +378,12 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Found the field, now check if it's writable
|
||||||
// Check bun tag for scanonly
|
// Check bun tag for scanonly
|
||||||
bunTag := field.Tag.Get("bun")
|
bunTag := field.Tag.Get("bun")
|
||||||
if bunTag != "" {
|
if bunTag != "" {
|
||||||
if isBunFieldScanOnly(bunTag) {
|
if isBunFieldScanOnly(bunTag) {
|
||||||
return false
|
return true, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,16 +391,16 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
gormTag := field.Tag.Get("gorm")
|
gormTag := field.Tag.Get("gorm")
|
||||||
if gormTag != "" {
|
if gormTag != "" {
|
||||||
if isGormFieldReadOnly(gormTag) {
|
if isGormFieldReadOnly(gormTag) {
|
||||||
return false
|
return true, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Column is writable
|
// Column is writable
|
||||||
return true
|
return true, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Column not found in model, allow it (might be a dynamic column)
|
// Column not found
|
||||||
return true
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||||
|
|||||||
@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test models with embedded structs
|
||||||
|
|
||||||
|
type BaseModel struct {
|
||||||
|
ID int `bun:"rid_base,pk" json:"id"`
|
||||||
|
CreatedAt string `bun:"created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelWithEmbedded struct {
|
||||||
|
BaseModel
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Description string `bun:"description" json:"description"`
|
||||||
|
AdhocBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormBaseModel struct {
|
||||||
|
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
|
||||||
|
CreatedAt string `gorm:"column:created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormAdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormModelWithEmbedded struct {
|
||||||
|
GormBaseModel
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
Description string `gorm:"column:description" json:"description"`
|
||||||
|
GormAdhocBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base (pointer)",
|
||||||
|
model: &ModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base (pointer)",
|
||||||
|
model: &GormModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyName(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
|
||||||
|
bunModel := ModelWithEmbedded{
|
||||||
|
BaseModel: BaseModel{
|
||||||
|
ID: 123,
|
||||||
|
CreatedAt: "2024-01-01",
|
||||||
|
},
|
||||||
|
Name: "Test",
|
||||||
|
Description: "Test Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
gormModel := GormModelWithEmbedded{
|
||||||
|
GormBaseModel: GormBaseModel{
|
||||||
|
ID: 456,
|
||||||
|
CreatedAt: "2024-01-02",
|
||||||
|
},
|
||||||
|
Name: "GORM Test",
|
||||||
|
Description: "GORM Test Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base",
|
||||||
|
model: bunModel,
|
||||||
|
expected: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base (pointer)",
|
||||||
|
model: &bunModel,
|
||||||
|
expected: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base",
|
||||||
|
model: gormModel,
|
||||||
|
expected: 456,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base (pointer)",
|
||||||
|
model: &gormModel,
|
||||||
|
expected: 456,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyValue(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelColumnsWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded structs",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded structs (pointer)",
|
||||||
|
model: &ModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded structs",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded structs (pointer)",
|
||||||
|
model: &GormModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
columnName string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model - writable column in main struct",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - writable column in embedded base",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "rid_base",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - scanonly column in embedded adhoc buffer",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "cql1",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - scanonly column _rownumber",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "_rownumber",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - writable column in main struct",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - writable column in embedded base",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "rid_base",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - readonly column in embedded adhoc buffer",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "cql1",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - readonly column _rownumber",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "_rownumber",
|
||||||
|
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsColumnWritable(tt.model, tt.columnName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import (
|
|||||||
type TestModel struct {
|
type TestModel struct {
|
||||||
ID int64 `json:"id" bun:"id,pk"`
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
Name string `json:"name" bun:"name"`
|
Name string `json:"name" bun:"name"`
|
||||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetRowNumbersOnRecords(t *testing.T) {
|
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user