feat(writer): 🎉 Resolve field name collisions with methods
All checks were successful
CI / Test (1.24) (push) Successful in -27m21s
CI / Test (1.25) (push) Successful in -27m12s
CI / Build (push) Successful in -27m37s
CI / Lint (push) Successful in -27m26s
Release / Build and Release (push) Successful in -27m25s
Integration Tests / Integration Tests (push) Successful in -27m20s
All checks were successful
CI / Test (1.24) (push) Successful in -27m21s
CI / Test (1.25) (push) Successful in -27m12s
CI / Build (push) Successful in -27m37s
CI / Lint (push) Successful in -27m26s
Release / Build and Release (push) Successful in -27m25s
Integration Tests / Integration Tests (push) Successful in -27m20s
* Implement field name collision resolution in model generation. * Add tests to verify renaming of fields that conflict with generated method names. * Ensure primary key type safety in UpdateID method.
This commit is contained in:
@@ -149,6 +149,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,6 +197,30 @@ func hasModelPrefix(name string) bool {
|
|||||||
return len(name) >= 5 && name[:5] == "Model"
|
return len(name) >= 5 && name[:5] == "Model"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
|
// and adds an underscore suffix if there's a collision
|
||||||
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
func sortColumns(columns map[string]*models.Column) []*models.Column {
|
func sortColumns(columns map[string]*models.Column) []*models.Column {
|
||||||
result := make([]*models.Column, 0, len(columns))
|
result := make([]*models.Column, 0, len(columns))
|
||||||
|
|||||||
@@ -481,6 +481,74 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `bun:"table_name,`) {
|
||||||
|
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper()
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type ModelData struct {
|
|||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
IDColumnName string // Name of the ID column in database
|
IDColumnName string // Name of the ID column in database
|
||||||
Prefix string // 3-letter prefix
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
@@ -135,6 +136,7 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Sanitize column name to remove backticks
|
// Sanitize column name to remove backticks
|
||||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -144,6 +146,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,6 +194,30 @@ func hasModelPrefix(name string) bool {
|
|||||||
return len(name) >= 5 && name[:5] == "Model"
|
return len(name) >= 5 && name[:5] == "Model"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
|
// and adds an underscore suffix if there's a collision
|
||||||
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
func sortColumns(columns map[string]*models.Column) []*models.Column {
|
func sortColumns(columns map[string]*models.Column) []*models.Column {
|
||||||
result := make([]*models.Column, 0, len(columns))
|
result := make([]*models.Column, 0, len(columns))
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
|
|||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||||
m.{{.PrimaryKeyField}} = int32(newid)
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||||
|
|||||||
@@ -470,6 +470,134 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `gorm:"column:table_name;`) {
|
||||||
|
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||||
|
// Test scenario: tables with different primary key types
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkType string
|
||||||
|
expectedPK string
|
||||||
|
castType string
|
||||||
|
}{
|
||||||
|
{"int32_pk", "int", "int32", "int32(newid)"},
|
||||||
|
{"int16_pk", "smallint", "int16", "int16(newid)"},
|
||||||
|
{"int64_pk", "bigint", "int64", "int64(newid)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table := models.InitTable("test_table", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: tt.pkType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify UpdateID method has correct type cast
|
||||||
|
if !strings.Contains(generated, tt.castType) {
|
||||||
|
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no invalid int32(newid) for non-int32 types
|
||||||
|
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
|
||||||
|
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify UpdateID parameter is int64 (for consistency)
|
||||||
|
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
Reference in New Issue
Block a user