diff --git a/pkg/writers/bun/template_data.go b/pkg/writers/bun/template_data.go index f607aa7..ab3435f 100644 --- a/pkg/writers/bun/template_data.go +++ b/pkg/writers/bun/template_data.go @@ -149,6 +149,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M columns := sortColumns(table.Columns) for _, col := range columns { field := columnToField(col, table, typeMapper) + // Check for name collision with generated methods and rename if needed + field.Name = resolveFieldNameCollision(field.Name) model.Fields = append(model.Fields, field) } @@ -195,6 +197,30 @@ func hasModelPrefix(name string) bool { return len(name) >= 5 && name[:5] == "Model" } +// resolveFieldNameCollision checks if a field name conflicts with generated method names +// and adds an underscore suffix if there's a collision +func resolveFieldNameCollision(fieldName string) string { + // List of method names that are generated by the template + reservedNames := map[string]bool{ + "TableName": true, + "TableNameOnly": true, + "SchemaName": true, + "GetID": true, + "GetIDStr": true, + "SetID": true, + "UpdateID": true, + "GetIDName": true, + "GetPrefix": true, + } + + // Check if field name conflicts with a reserved method name + if reservedNames[fieldName] { + return fieldName + "_" + } + + return fieldName +} + // sortColumns sorts columns by sequence, then by name func sortColumns(columns map[string]*models.Column) []*models.Column { result := make([]*models.Column, 0, len(columns)) diff --git a/pkg/writers/bun/writer_test.go b/pkg/writers/bun/writer_test.go index a5e86a9..39d9d92 100644 --- a/pkg/writers/bun/writer_test.go +++ b/pkg/writers/bun/writer_test.go @@ -481,6 +481,74 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) { } } +func TestWriter_FieldNameCollision(t *testing.T) { + // Test scenario: table with columns that would conflict with generated method names + table := models.InitTable("audit_table", "audit") + table.Columns["id_audit_table"] = &models.Column{ + Name: "id_audit_table", + Type: "smallint", + NotNull: true, + IsPrimaryKey: true, + Sequence: 1, + } + table.Columns["table_name"] = &models.Column{ + Name: "table_name", + Type: "varchar", + Length: 100, + NotNull: true, + Sequence: 2, + } + table.Columns["table_schema"] = &models.Column{ + Name: "table_schema", + Type: "varchar", + Length: 100, + NotNull: true, + Sequence: 3, + } + + // Create writer + tmpDir := t.TempDir() + opts := &writers.WriterOptions{ + PackageName: "models", + OutputPath: filepath.Join(tmpDir, "test.go"), + } + + writer := NewWriter(opts) + + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + // Read the generated file + content, err := os.ReadFile(opts.OutputPath) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + generated := string(content) + + // Verify that TableName field was renamed to TableName_ to avoid collision + if !strings.Contains(generated, "TableName_") { + t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated) + } + + // Verify the struct tag still references the correct database column + if !strings.Contains(generated, `bun:"table_name,`) { + t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated) + } + + // Verify the TableName() method still exists and doesn't conflict + if !strings.Contains(generated, "func (m ModelAuditTable) TableName() string") { + t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated) + } + + // Verify NO field named just "TableName" (without underscore) + if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") { + t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated) + } +} + func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) { mapper := NewTypeMapper() diff --git a/pkg/writers/gorm/template_data.go b/pkg/writers/gorm/template_data.go index 9f6251b..a8fe5e2 100644 --- a/pkg/writers/gorm/template_data.go +++ b/pkg/writers/gorm/template_data.go @@ -25,6 +25,7 @@ type ModelData struct { Fields []*FieldData Config *MethodConfig PrimaryKeyField string // Name of the primary key field + PrimaryKeyType string // Go type of the primary key field IDColumnName string // Name of the ID column in database Prefix string // 3-letter prefix } @@ -135,6 +136,7 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M // Sanitize column name to remove backticks safeName := writers.SanitizeStructTagValue(col.Name) model.PrimaryKeyField = SnakeCaseToPascalCase(safeName) + model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull) model.IDColumnName = safeName break } @@ -144,6 +146,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M columns := sortColumns(table.Columns) for _, col := range columns { field := columnToField(col, table, typeMapper) + // Check for name collision with generated methods and rename if needed + field.Name = resolveFieldNameCollision(field.Name) model.Fields = append(model.Fields, field) } @@ -190,6 +194,30 @@ func hasModelPrefix(name string) bool { return len(name) >= 5 && name[:5] == "Model" } +// resolveFieldNameCollision checks if a field name conflicts with generated method names +// and adds an underscore suffix if there's a collision +func resolveFieldNameCollision(fieldName string) string { + // List of method names that are generated by the template + reservedNames := map[string]bool{ + "TableName": true, + "TableNameOnly": true, + "SchemaName": true, + "GetID": true, + "GetIDStr": true, + "SetID": true, + "UpdateID": true, + "GetIDName": true, + "GetPrefix": true, + } + + // Check if field name conflicts with a reserved method name + if reservedNames[fieldName] { + return fieldName + "_" + } + + return fieldName +} + // sortColumns sorts columns by sequence, then by name func sortColumns(columns map[string]*models.Column) []*models.Column { result := make([]*models.Column, 0, len(columns)) diff --git a/pkg/writers/gorm/templates.go b/pkg/writers/gorm/templates.go index a7b1f00..5ff0cd5 100644 --- a/pkg/writers/gorm/templates.go +++ b/pkg/writers/gorm/templates.go @@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) { {{if and .Config.GenerateUpdateID .PrimaryKeyField}} // UpdateID updates the primary key value func (m *{{.Name}}) UpdateID(newid int64) { - m.{{.PrimaryKeyField}} = int32(newid) + m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid) } {{end}} {{if and .Config.GenerateGetIDName .IDColumnName}} diff --git a/pkg/writers/gorm/writer_test.go b/pkg/writers/gorm/writer_test.go index ef3f90b..d204a31 100644 --- a/pkg/writers/gorm/writer_test.go +++ b/pkg/writers/gorm/writer_test.go @@ -470,6 +470,134 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) { } } +func TestWriter_FieldNameCollision(t *testing.T) { + // Test scenario: table with columns that would conflict with generated method names + table := models.InitTable("audit_table", "audit") + table.Columns["id_audit_table"] = &models.Column{ + Name: "id_audit_table", + Type: "smallint", + NotNull: true, + IsPrimaryKey: true, + Sequence: 1, + } + table.Columns["table_name"] = &models.Column{ + Name: "table_name", + Type: "varchar", + Length: 100, + NotNull: true, + Sequence: 2, + } + table.Columns["table_schema"] = &models.Column{ + Name: "table_schema", + Type: "varchar", + Length: 100, + NotNull: true, + Sequence: 3, + } + + // Create writer + tmpDir := t.TempDir() + opts := &writers.WriterOptions{ + PackageName: "models", + OutputPath: filepath.Join(tmpDir, "test.go"), + } + + writer := NewWriter(opts) + + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + // Read the generated file + content, err := os.ReadFile(opts.OutputPath) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + generated := string(content) + + // Verify that TableName field was renamed to TableName_ to avoid collision + if !strings.Contains(generated, "TableName_") { + t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated) + } + + // Verify the struct tag still references the correct database column + if !strings.Contains(generated, `gorm:"column:table_name;`) { + t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated) + } + + // Verify the TableName() method still exists and doesn't conflict + if !strings.Contains(generated, "func (m ModelAuditTable) TableName() string") { + t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated) + } + + // Verify NO field named just "TableName" (without underscore) + if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") { + t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated) + } +} + +func TestWriter_UpdateIDTypeSafety(t *testing.T) { + // Test scenario: tables with different primary key types + tests := []struct { + name string + pkType string + expectedPK string + castType string + }{ + {"int32_pk", "int", "int32", "int32(newid)"}, + {"int16_pk", "smallint", "int16", "int16(newid)"}, + {"int64_pk", "bigint", "int64", "int64(newid)"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + table := models.InitTable("test_table", "public") + table.Columns["id"] = &models.Column{ + Name: "id", + Type: tt.pkType, + NotNull: true, + IsPrimaryKey: true, + } + + tmpDir := t.TempDir() + opts := &writers.WriterOptions{ + PackageName: "models", + OutputPath: filepath.Join(tmpDir, "test.go"), + } + + writer := NewWriter(opts) + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + content, err := os.ReadFile(opts.OutputPath) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + generated := string(content) + + // Verify UpdateID method has correct type cast + if !strings.Contains(generated, tt.castType) { + t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated) + } + + // Verify no invalid int32(newid) for non-int32 types + if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") { + t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated) + } + + // Verify UpdateID parameter is int64 (for consistency) + if !strings.Contains(generated, "UpdateID(newid int64)") { + t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated) + } + }) + } +} + func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) { tests := []struct { input string