Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6f55505444 | |||
| e0e7b64c69 |
@@ -112,13 +112,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
|
||||
// Generate model name: singularize and convert to PascalCase
|
||||
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||
singularTable := Singularize(table.Name)
|
||||
modelName := SnakeCaseToPascalCase(singularTable)
|
||||
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||
|
||||
// Add "Model" prefix if not already present
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schema != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schema)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
model := &ModelData{
|
||||
@@ -149,6 +153,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
columns := sortColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
field := columnToField(col, table, typeMapper)
|
||||
// Check for name collision with generated methods and rename if needed
|
||||
field.Name = resolveFieldNameCollision(field.Name)
|
||||
model.Fields = append(model.Fields, field)
|
||||
}
|
||||
|
||||
@@ -190,9 +196,28 @@ func formatComment(description, comment string) string {
|
||||
return comment
|
||||
}
|
||||
|
||||
// hasModelPrefix checks if a name already has "Model" prefix
|
||||
func hasModelPrefix(name string) bool {
|
||||
return len(name) >= 5 && name[:5] == "Model"
|
||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||
// and adds an underscore suffix if there's a collision
|
||||
func resolveFieldNameCollision(fieldName string) string {
|
||||
// List of method names that are generated by the template
|
||||
reservedNames := map[string]bool{
|
||||
"TableName": true,
|
||||
"TableNameOnly": true,
|
||||
"SchemaName": true,
|
||||
"GetID": true,
|
||||
"GetIDStr": true,
|
||||
"SetID": true,
|
||||
"UpdateID": true,
|
||||
"GetIDName": true,
|
||||
"GetPrefix": true,
|
||||
}
|
||||
|
||||
// Check if field name conflicts with a reserved method name
|
||||
if reservedNames[fieldName] {
|
||||
return fieldName + "_"
|
||||
}
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// sortColumns sorts columns by sequence, then by name
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(formatted)
|
||||
if err := w.writeOutput(formatted); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run go fmt on the output file
|
||||
if w.options.OutputPath != "" {
|
||||
w.runGoFmt(w.options.OutputPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
@@ -217,6 +227,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Run go fmt on the generated file
|
||||
w.runGoFmt(filepath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,7 +254,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
}
|
||||
|
||||
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||
fieldName := w.generateHasOneFieldName(constraint)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
||||
@@ -270,8 +283,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
// Check if this constraint references our table
|
||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||
// Add has-many relationship
|
||||
otherModelName := w.getModelName(otherTable.Name)
|
||||
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
||||
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
||||
|
||||
@@ -303,13 +316,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
||||
return nil
|
||||
}
|
||||
|
||||
// getModelName generates the model name from a table name
|
||||
func (w *Writer) getModelName(tableName string) string {
|
||||
// getModelName generates the model name from schema and table name
|
||||
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
modelName := SnakeCaseToPascalCase(singular)
|
||||
tablePart := SnakeCaseToPascalCase(singular)
|
||||
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schemaName != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
return modelName
|
||||
@@ -333,13 +351,13 @@ func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
||||
|
||||
// generateHasManyFieldName generates a field name for has-many relationships
|
||||
// Uses the foreign key column name + source table name to avoid duplicates
|
||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||
// For has-many, we need to include the source table name to avoid duplicates
|
||||
// e.g., multiple tables referencing the same column on this table
|
||||
if len(constraint.Columns) > 0 {
|
||||
columnName := constraint.Columns[0]
|
||||
// Get the model name for the source table (pluralized)
|
||||
sourceModelName := w.getModelName(sourceTableName)
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
// Remove "Model" prefix if present
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
|
||||
@@ -350,7 +368,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT
|
||||
}
|
||||
|
||||
// Fallback to table-based naming
|
||||
sourceModelName := w.getModelName(sourceTableName)
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
return "Rel" + Pluralize(sourceModelName)
|
||||
}
|
||||
@@ -399,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runGoFmt runs go fmt on the specified file
|
||||
func (w *Writer) runGoFmt(filepath string) {
|
||||
cmd := exec.Command("gofmt", "-w", filepath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Don't fail the whole operation if gofmt fails, just warn
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||
func (w *Writer) shouldUseMultiFile() bool {
|
||||
// Check if multi_file is explicitly set in metadata
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelUser struct",
|
||||
"type ModelPublicUser struct",
|
||||
"bun.BaseModel",
|
||||
"table:public.users",
|
||||
"alias:users",
|
||||
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
"resolvespec_common.SqlTime",
|
||||
"bun:\"id",
|
||||
"bun:\"email",
|
||||
"func (m ModelUser) TableName() string",
|
||||
"func (m ModelPublicUser) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelUser) GetID() int64",
|
||||
"func (m ModelPublicUser) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
@@ -191,9 +191,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
||||
|
||||
usersStr := string(usersContent)
|
||||
|
||||
// Should have RelUserIDPosts (has-many) field
|
||||
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
||||
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
||||
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,8 +309,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||
|
||||
// Should have two different has-many relationships with unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
||||
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
@@ -455,10 +455,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
|
||||
// Verify all has-many relationships have unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDAPIProviderLogins", // Has many via Login
|
||||
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Has one via rid_owner
|
||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Has one via rid_owner
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
@@ -481,6 +481,74 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||
// Test scenario: table with columns that would conflict with generated method names
|
||||
table := models.InitTable("audit_table", "audit")
|
||||
table.Columns["id_audit_table"] = &models.Column{
|
||||
Name: "id_audit_table",
|
||||
Type: "smallint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
Sequence: 1,
|
||||
}
|
||||
table.Columns["table_name"] = &models.Column{
|
||||
Name: "table_name",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 2,
|
||||
}
|
||||
table.Columns["table_schema"] = &models.Column{
|
||||
Name: "table_schema",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 3,
|
||||
}
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the generated file
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||
if !strings.Contains(generated, "TableName_") {
|
||||
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the struct tag still references the correct database column
|
||||
if !strings.Contains(generated, `bun:"table_name,`) {
|
||||
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the TableName() method still exists and doesn't conflict
|
||||
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify NO field named just "TableName" (without underscore)
|
||||
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
|
||||
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||
mapper := NewTypeMapper()
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ type ModelData struct {
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
PrimaryKeyType string // Go type of the primary key field
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
}
|
||||
@@ -110,13 +111,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
|
||||
// Generate model name: singularize and convert to PascalCase
|
||||
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||
singularTable := Singularize(table.Name)
|
||||
modelName := SnakeCaseToPascalCase(singularTable)
|
||||
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||
|
||||
// Add "Model" prefix if not already present
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schema != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schema)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
model := &ModelData{
|
||||
@@ -135,6 +140,7 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
// Sanitize column name to remove backticks
|
||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
model.IDColumnName = safeName
|
||||
break
|
||||
}
|
||||
@@ -144,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
columns := sortColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
field := columnToField(col, table, typeMapper)
|
||||
// Check for name collision with generated methods and rename if needed
|
||||
field.Name = resolveFieldNameCollision(field.Name)
|
||||
model.Fields = append(model.Fields, field)
|
||||
}
|
||||
|
||||
@@ -185,9 +193,28 @@ func formatComment(description, comment string) string {
|
||||
return comment
|
||||
}
|
||||
|
||||
// hasModelPrefix checks if a name already has "Model" prefix
|
||||
func hasModelPrefix(name string) bool {
|
||||
return len(name) >= 5 && name[:5] == "Model"
|
||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||
// and adds an underscore suffix if there's a collision
|
||||
func resolveFieldNameCollision(fieldName string) string {
|
||||
// List of method names that are generated by the template
|
||||
reservedNames := map[string]bool{
|
||||
"TableName": true,
|
||||
"TableNameOnly": true,
|
||||
"SchemaName": true,
|
||||
"GetID": true,
|
||||
"GetIDStr": true,
|
||||
"SetID": true,
|
||||
"UpdateID": true,
|
||||
"GetIDName": true,
|
||||
"GetPrefix": true,
|
||||
}
|
||||
|
||||
// Check if field name conflicts with a reserved method name
|
||||
if reservedNames[fieldName] {
|
||||
return fieldName + "_"
|
||||
}
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// sortColumns sorts columns by sequence, then by name
|
||||
|
||||
@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
|
||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||
// UpdateID updates the primary key value
|
||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||
m.{{.PrimaryKeyField}} = int32(newid)
|
||||
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(formatted)
|
||||
if err := w.writeOutput(formatted); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run go fmt on the output file
|
||||
if w.options.OutputPath != "" {
|
||||
w.runGoFmt(w.options.OutputPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
@@ -211,6 +221,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Run go fmt on the generated file
|
||||
w.runGoFmt(filepath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,7 +248,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
}
|
||||
|
||||
// Create relationship field (belongs-to)
|
||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||
fieldName := w.generateBelongsToFieldName(constraint)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||
@@ -264,8 +277,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
// Check if this constraint references our table
|
||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||
// Add has-many relationship
|
||||
otherModelName := w.getModelName(otherTable.Name)
|
||||
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
||||
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||
|
||||
@@ -297,13 +310,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
||||
return nil
|
||||
}
|
||||
|
||||
// getModelName generates the model name from a table name
|
||||
func (w *Writer) getModelName(tableName string) string {
|
||||
// getModelName generates the model name from schema and table name
|
||||
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
modelName := SnakeCaseToPascalCase(singular)
|
||||
tablePart := SnakeCaseToPascalCase(singular)
|
||||
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schemaName != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
return modelName
|
||||
@@ -327,13 +345,13 @@ func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) strin
|
||||
|
||||
// generateHasManyFieldName generates a field name for has-many relationships
|
||||
// Uses the foreign key column name + source table name to avoid duplicates
|
||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||
// For has-many, we need to include the source table name to avoid duplicates
|
||||
// e.g., multiple tables referencing the same column on this table
|
||||
if len(constraint.Columns) > 0 {
|
||||
columnName := constraint.Columns[0]
|
||||
// Get the model name for the source table (pluralized)
|
||||
sourceModelName := w.getModelName(sourceTableName)
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
// Remove "Model" prefix if present
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
|
||||
@@ -344,7 +362,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT
|
||||
}
|
||||
|
||||
// Fallback to table-based naming
|
||||
sourceModelName := w.getModelName(sourceTableName)
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
return "Rel" + Pluralize(sourceModelName)
|
||||
}
|
||||
@@ -393,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runGoFmt runs go fmt on the specified file
|
||||
func (w *Writer) runGoFmt(filepath string) {
|
||||
cmd := exec.Command("gofmt", "-w", filepath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Don't fail the whole operation if gofmt fails, just warn
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||
func (w *Writer) shouldUseMultiFile() bool {
|
||||
// Check if multi_file is explicitly set in metadata
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelUser struct",
|
||||
"type ModelPublicUser struct",
|
||||
"ID",
|
||||
"int64",
|
||||
"Email",
|
||||
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
"time.Time",
|
||||
"gorm:\"column:id",
|
||||
"gorm:\"column:email",
|
||||
"func (m ModelUser) TableName() string",
|
||||
"func (m ModelPublicUser) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelUser) GetID() int64",
|
||||
"func (m ModelPublicUser) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
@@ -180,9 +180,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
||||
|
||||
usersStr := string(usersContent)
|
||||
|
||||
// Should have RelUserIDPosts (has-many) field
|
||||
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
||||
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
||||
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,8 +298,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||
|
||||
// Should have two different has-many relationships with unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
||||
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
@@ -444,10 +444,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
|
||||
// Verify all has-many relationships have unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDAPIProviderLogins", // Has many via Login
|
||||
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Belongs to via rid_owner
|
||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Belongs to via rid_owner
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
@@ -470,6 +470,134 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||
// Test scenario: table with columns that would conflict with generated method names
|
||||
table := models.InitTable("audit_table", "audit")
|
||||
table.Columns["id_audit_table"] = &models.Column{
|
||||
Name: "id_audit_table",
|
||||
Type: "smallint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
Sequence: 1,
|
||||
}
|
||||
table.Columns["table_name"] = &models.Column{
|
||||
Name: "table_name",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 2,
|
||||
}
|
||||
table.Columns["table_schema"] = &models.Column{
|
||||
Name: "table_schema",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 3,
|
||||
}
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the generated file
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||
if !strings.Contains(generated, "TableName_") {
|
||||
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the struct tag still references the correct database column
|
||||
if !strings.Contains(generated, `gorm:"column:table_name;`) {
|
||||
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the TableName() method still exists and doesn't conflict
|
||||
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify NO field named just "TableName" (without underscore)
|
||||
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
|
||||
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||
// Test scenario: tables with different primary key types
|
||||
tests := []struct {
|
||||
name string
|
||||
pkType string
|
||||
expectedPK string
|
||||
castType string
|
||||
}{
|
||||
{"int32_pk", "int", "int32", "int32(newid)"},
|
||||
{"int16_pk", "smallint", "int16", "int16(newid)"},
|
||||
{"int64_pk", "bigint", "int64", "int64(newid)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
table := models.InitTable("test_table", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: tt.pkType,
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify UpdateID method has correct type cast
|
||||
if !strings.Contains(generated, tt.castType) {
|
||||
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
|
||||
}
|
||||
|
||||
// Verify no invalid int32(newid) for non-int32 types
|
||||
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
|
||||
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||
}
|
||||
|
||||
// Verify UpdateID parameter is int64 (for consistency)
|
||||
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
Reference in New Issue
Block a user