From 6f55505444ef10b195670ec042a7190843ea932a Mon Sep 17 00:00:00 2001 From: Hein Date: Sat, 10 Jan 2026 18:28:41 +0200 Subject: [PATCH] =?UTF-8?q?feat(writer):=20=F0=9F=8E=89=20Enhance=20model?= =?UTF-8?q?=20name=20generation=20and=20formatting?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update model name generation to include schema name. * Add gofmt execution after writing output files. * Refactor relationship field naming to include schema. * Update tests to reflect changes in model names and relationships. --- pkg/writers/bun/template_data.go | 19 ++++++------ pkg/writers/bun/writer.go | 51 +++++++++++++++++++++++-------- pkg/writers/bun/writer_test.go | 26 ++++++++-------- pkg/writers/gorm/template_data.go | 19 ++++++------ pkg/writers/gorm/writer.go | 51 +++++++++++++++++++++++-------- pkg/writers/gorm/writer_test.go | 26 ++++++++-------- 6 files changed, 122 insertions(+), 70 deletions(-) diff --git a/pkg/writers/bun/template_data.go b/pkg/writers/bun/template_data.go index ab3435f..5d0a1be 100644 --- a/pkg/writers/bun/template_data.go +++ b/pkg/writers/bun/template_data.go @@ -112,13 +112,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M tableName = schema + "." + table.Name } - // Generate model name: singularize and convert to PascalCase + // Generate model name: Model + Schema + Table (all PascalCase) singularTable := Singularize(table.Name) - modelName := SnakeCaseToPascalCase(singularTable) + tablePart := SnakeCaseToPascalCase(singularTable) - // Add "Model" prefix if not already present - if !hasModelPrefix(modelName) { - modelName = "Model" + modelName + // Include schema name in model name + var modelName string + if schema != "" { + schemaPart := SnakeCaseToPascalCase(schema) + modelName = "Model" + schemaPart + tablePart + } else { + modelName = "Model" + tablePart } model := &ModelData{ @@ -192,11 +196,6 @@ func formatComment(description, comment string) string { return comment } -// hasModelPrefix checks if a name already has "Model" prefix -func hasModelPrefix(name string) bool { - return len(name) >= 5 && name[:5] == "Model" -} - // resolveFieldNameCollision checks if a field name conflicts with generated method names // and adds an underscore suffix if there's a collision func resolveFieldNameCollision(fieldName string) string { diff --git a/pkg/writers/bun/writer.go b/pkg/writers/bun/writer.go index a7ba21f..647bf16 100644 --- a/pkg/writers/bun/writer.go +++ b/pkg/writers/bun/writer.go @@ -4,6 +4,7 @@ import ( "fmt" "go/format" "os" + "os/exec" "path/filepath" "strings" @@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error { } // Write output - return w.writeOutput(formatted) + if err := w.writeOutput(formatted); err != nil { + return err + } + + // Run go fmt on the output file + if w.options.OutputPath != "" { + w.runGoFmt(w.options.OutputPath) + } + + return nil } // writeMultiFile writes each table to a separate file @@ -217,6 +227,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error { if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil { return fmt.Errorf("failed to write file %s: %w", filename, err) } + + // Run go fmt on the generated file + w.runGoFmt(filepath) } } @@ -241,7 +254,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table } // Create relationship field (has-one in Bun, similar to belongs-to in GORM) - refModelName := w.getModelName(constraint.ReferencedTable) + refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable) fieldName := w.generateHasOneFieldName(constraint) fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames) relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one") @@ -270,8 +283,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table // Check if this constraint references our table if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name { // Add has-many relationship - otherModelName := w.getModelName(otherTable.Name) - fieldName := w.generateHasManyFieldName(constraint, otherTable.Name) + otherModelName := w.getModelName(otherSchema.Name, otherTable.Name) + fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name) fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames) relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many") @@ -303,13 +316,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m return nil } -// getModelName generates the model name from a table name -func (w *Writer) getModelName(tableName string) string { +// getModelName generates the model name from schema and table name +func (w *Writer) getModelName(schemaName, tableName string) string { singular := Singularize(tableName) - modelName := SnakeCaseToPascalCase(singular) + tablePart := SnakeCaseToPascalCase(singular) - if !hasModelPrefix(modelName) { - modelName = "Model" + modelName + // Include schema name in model name + var modelName string + if schemaName != "" { + schemaPart := SnakeCaseToPascalCase(schemaName) + modelName = "Model" + schemaPart + tablePart + } else { + modelName = "Model" + tablePart } return modelName @@ -333,13 +351,13 @@ func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string { // generateHasManyFieldName generates a field name for has-many relationships // Uses the foreign key column name + source table name to avoid duplicates -func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string { +func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string { // For has-many, we need to include the source table name to avoid duplicates // e.g., multiple tables referencing the same column on this table if len(constraint.Columns) > 0 { columnName := constraint.Columns[0] // Get the model name for the source table (pluralized) - sourceModelName := w.getModelName(sourceTableName) + sourceModelName := w.getModelName(sourceSchemaName, sourceTableName) // Remove "Model" prefix if present sourceModelName = strings.TrimPrefix(sourceModelName, "Model") @@ -350,7 +368,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT } // Fallback to table-based naming - sourceModelName := w.getModelName(sourceTableName) + sourceModelName := w.getModelName(sourceSchemaName, sourceTableName) sourceModelName = strings.TrimPrefix(sourceModelName, "Model") return "Rel" + Pluralize(sourceModelName) } @@ -399,6 +417,15 @@ func (w *Writer) writeOutput(content string) error { return nil } +// runGoFmt runs go fmt on the specified file +func (w *Writer) runGoFmt(filepath string) { + cmd := exec.Command("gofmt", "-w", filepath) + if err := cmd.Run(); err != nil { + // Don't fail the whole operation if gofmt fails, just warn + fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err) + } +} + // shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path func (w *Writer) shouldUseMultiFile() bool { // Check if multi_file is explicitly set in metadata diff --git a/pkg/writers/bun/writer_test.go b/pkg/writers/bun/writer_test.go index 39d9d92..9ecd74e 100644 --- a/pkg/writers/bun/writer_test.go +++ b/pkg/writers/bun/writer_test.go @@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) { // Verify key elements are present expectations := []string{ "package models", - "type ModelUser struct", + "type ModelPublicUser struct", "bun.BaseModel", "table:public.users", "alias:users", @@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) { "resolvespec_common.SqlTime", "bun:\"id", "bun:\"email", - "func (m ModelUser) TableName() string", + "func (m ModelPublicUser) TableName() string", "return \"public.users\"", - "func (m ModelUser) GetID() int64", + "func (m ModelPublicUser) GetID() int64", } for _, expected := range expectations { @@ -191,9 +191,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) { usersStr := string(usersContent) - // Should have RelUserIDPosts (has-many) field - if !strings.Contains(usersStr, "RelUserIDPosts") { - t.Errorf("Missing has-many relationship field RelUserIDPosts") + // Should have RelUserIDPublicPosts (has-many) field - includes schema prefix + if !strings.Contains(usersStr, "RelUserIDPublicPosts") { + t.Errorf("Missing has-many relationship field RelUserIDPublicPosts") } } @@ -309,8 +309,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) { // Should have two different has-many relationships with unique names hasManyExpectations := []string{ - "RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request - "RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response + "RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request + "RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response } for _, exp := range hasManyExpectations { @@ -455,10 +455,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) { // Verify all has-many relationships have unique names hasManyExpectations := []string{ - "RelRIDAPIProviderLogins", // Has many via Login - "RelRIDAPIProviderFilepointers", // Has many via Filepointer - "RelRIDAPIProviderAPIEvents", // Has many via APIEvent - "RelRIDOwner", // Has one via rid_owner + "RelRIDAPIProviderOrgLogins", // Has many via Login + "RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer + "RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent + "RelRIDOwner", // Has one via rid_owner } for _, exp := range hasManyExpectations { @@ -539,7 +539,7 @@ func TestWriter_FieldNameCollision(t *testing.T) { } // Verify the TableName() method still exists and doesn't conflict - if !strings.Contains(generated, "func (m ModelAuditTable) TableName() string") { + if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") { t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated) } diff --git a/pkg/writers/gorm/template_data.go b/pkg/writers/gorm/template_data.go index a8fe5e2..84fc3a0 100644 --- a/pkg/writers/gorm/template_data.go +++ b/pkg/writers/gorm/template_data.go @@ -111,13 +111,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M tableName = schema + "." + table.Name } - // Generate model name: singularize and convert to PascalCase + // Generate model name: Model + Schema + Table (all PascalCase) singularTable := Singularize(table.Name) - modelName := SnakeCaseToPascalCase(singularTable) + tablePart := SnakeCaseToPascalCase(singularTable) - // Add "Model" prefix if not already present - if !hasModelPrefix(modelName) { - modelName = "Model" + modelName + // Include schema name in model name + var modelName string + if schema != "" { + schemaPart := SnakeCaseToPascalCase(schema) + modelName = "Model" + schemaPart + tablePart + } else { + modelName = "Model" + tablePart } model := &ModelData{ @@ -189,11 +193,6 @@ func formatComment(description, comment string) string { return comment } -// hasModelPrefix checks if a name already has "Model" prefix -func hasModelPrefix(name string) bool { - return len(name) >= 5 && name[:5] == "Model" -} - // resolveFieldNameCollision checks if a field name conflicts with generated method names // and adds an underscore suffix if there's a collision func resolveFieldNameCollision(fieldName string) string { diff --git a/pkg/writers/gorm/writer.go b/pkg/writers/gorm/writer.go index 2a03655..db7c14a 100644 --- a/pkg/writers/gorm/writer.go +++ b/pkg/writers/gorm/writer.go @@ -4,6 +4,7 @@ import ( "fmt" "go/format" "os" + "os/exec" "path/filepath" "strings" @@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error { } // Write output - return w.writeOutput(formatted) + if err := w.writeOutput(formatted); err != nil { + return err + } + + // Run go fmt on the output file + if w.options.OutputPath != "" { + w.runGoFmt(w.options.OutputPath) + } + + return nil } // writeMultiFile writes each table to a separate file @@ -211,6 +221,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error { if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil { return fmt.Errorf("failed to write file %s: %w", filename, err) } + + // Run go fmt on the generated file + w.runGoFmt(filepath) } } @@ -235,7 +248,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table } // Create relationship field (belongs-to) - refModelName := w.getModelName(constraint.ReferencedTable) + refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable) fieldName := w.generateBelongsToFieldName(constraint) fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames) relationTag := w.typeMapper.BuildRelationshipTag(constraint, false) @@ -264,8 +277,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table // Check if this constraint references our table if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name { // Add has-many relationship - otherModelName := w.getModelName(otherTable.Name) - fieldName := w.generateHasManyFieldName(constraint, otherTable.Name) + otherModelName := w.getModelName(otherSchema.Name, otherTable.Name) + fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name) fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames) relationTag := w.typeMapper.BuildRelationshipTag(constraint, true) @@ -297,13 +310,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m return nil } -// getModelName generates the model name from a table name -func (w *Writer) getModelName(tableName string) string { +// getModelName generates the model name from schema and table name +func (w *Writer) getModelName(schemaName, tableName string) string { singular := Singularize(tableName) - modelName := SnakeCaseToPascalCase(singular) + tablePart := SnakeCaseToPascalCase(singular) - if !hasModelPrefix(modelName) { - modelName = "Model" + modelName + // Include schema name in model name + var modelName string + if schemaName != "" { + schemaPart := SnakeCaseToPascalCase(schemaName) + modelName = "Model" + schemaPart + tablePart + } else { + modelName = "Model" + tablePart } return modelName @@ -327,13 +345,13 @@ func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) strin // generateHasManyFieldName generates a field name for has-many relationships // Uses the foreign key column name + source table name to avoid duplicates -func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string { +func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string { // For has-many, we need to include the source table name to avoid duplicates // e.g., multiple tables referencing the same column on this table if len(constraint.Columns) > 0 { columnName := constraint.Columns[0] // Get the model name for the source table (pluralized) - sourceModelName := w.getModelName(sourceTableName) + sourceModelName := w.getModelName(sourceSchemaName, sourceTableName) // Remove "Model" prefix if present sourceModelName = strings.TrimPrefix(sourceModelName, "Model") @@ -344,7 +362,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT } // Fallback to table-based naming - sourceModelName := w.getModelName(sourceTableName) + sourceModelName := w.getModelName(sourceSchemaName, sourceTableName) sourceModelName = strings.TrimPrefix(sourceModelName, "Model") return "Rel" + Pluralize(sourceModelName) } @@ -393,6 +411,15 @@ func (w *Writer) writeOutput(content string) error { return nil } +// runGoFmt runs go fmt on the specified file +func (w *Writer) runGoFmt(filepath string) { + cmd := exec.Command("gofmt", "-w", filepath) + if err := cmd.Run(); err != nil { + // Don't fail the whole operation if gofmt fails, just warn + fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err) + } +} + // shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path func (w *Writer) shouldUseMultiFile() bool { // Check if multi_file is explicitly set in metadata diff --git a/pkg/writers/gorm/writer_test.go b/pkg/writers/gorm/writer_test.go index d204a31..e6e9137 100644 --- a/pkg/writers/gorm/writer_test.go +++ b/pkg/writers/gorm/writer_test.go @@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) { // Verify key elements are present expectations := []string{ "package models", - "type ModelUser struct", + "type ModelPublicUser struct", "ID", "int64", "Email", @@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) { "time.Time", "gorm:\"column:id", "gorm:\"column:email", - "func (m ModelUser) TableName() string", + "func (m ModelPublicUser) TableName() string", "return \"public.users\"", - "func (m ModelUser) GetID() int64", + "func (m ModelPublicUser) GetID() int64", } for _, expected := range expectations { @@ -180,9 +180,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) { usersStr := string(usersContent) - // Should have RelUserIDPosts (has-many) field - if !strings.Contains(usersStr, "RelUserIDPosts") { - t.Errorf("Missing has-many relationship field RelUserIDPosts") + // Should have RelUserIDPublicPosts (has-many) field - includes schema prefix + if !strings.Contains(usersStr, "RelUserIDPublicPosts") { + t.Errorf("Missing has-many relationship field RelUserIDPublicPosts") } } @@ -298,8 +298,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) { // Should have two different has-many relationships with unique names hasManyExpectations := []string{ - "RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request - "RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response + "RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request + "RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response } for _, exp := range hasManyExpectations { @@ -444,10 +444,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) { // Verify all has-many relationships have unique names hasManyExpectations := []string{ - "RelRIDAPIProviderLogins", // Has many via Login - "RelRIDAPIProviderFilepointers", // Has many via Filepointer - "RelRIDAPIProviderAPIEvents", // Has many via APIEvent - "RelRIDOwner", // Belongs to via rid_owner + "RelRIDAPIProviderOrgLogins", // Has many via Login + "RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer + "RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent + "RelRIDOwner", // Belongs to via rid_owner } for _, exp := range hasManyExpectations { @@ -528,7 +528,7 @@ func TestWriter_FieldNameCollision(t *testing.T) { } // Verify the TableName() method still exists and doesn't conflict - if !strings.Contains(generated, "func (m ModelAuditTable) TableName() string") { + if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") { t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated) }