Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6f55505444 | |||
| e0e7b64c69 |
@@ -112,13 +112,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
tableName = schema + "." + table.Name
|
tableName = schema + "." + table.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate model name: singularize and convert to PascalCase
|
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||||
singularTable := Singularize(table.Name)
|
singularTable := Singularize(table.Name)
|
||||||
modelName := SnakeCaseToPascalCase(singularTable)
|
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||||
|
|
||||||
// Add "Model" prefix if not already present
|
// Include schema name in model name
|
||||||
if !hasModelPrefix(modelName) {
|
var modelName string
|
||||||
modelName = "Model" + modelName
|
if schema != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schema)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &ModelData{
|
model := &ModelData{
|
||||||
@@ -149,6 +153,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,9 +196,28 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasModelPrefix checks if a name already has "Model" prefix
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
func hasModelPrefix(name string) bool {
|
// and adds an underscore suffix if there's a collision
|
||||||
return len(name) >= 5 && name[:5] == "Model"
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
return w.writeOutput(formatted)
|
if err := w.writeOutput(formatted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the output file
|
||||||
|
if w.options.OutputPath != "" {
|
||||||
|
w.runGoFmt(w.options.OutputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMultiFile writes each table to a separate file
|
// writeMultiFile writes each table to a separate file
|
||||||
@@ -217,6 +227,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the generated file
|
||||||
|
w.runGoFmt(filepath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,7 +254,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||||
fieldName := w.generateHasOneFieldName(constraint)
|
fieldName := w.generateHasOneFieldName(constraint)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
||||||
@@ -270,8 +283,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
// Check if this constraint references our table
|
// Check if this constraint references our table
|
||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||||
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
||||||
|
|
||||||
@@ -303,13 +316,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModelName generates the model name from a table name
|
// getModelName generates the model name from schema and table name
|
||||||
func (w *Writer) getModelName(tableName string) string {
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||||
singular := Singularize(tableName)
|
singular := Singularize(tableName)
|
||||||
modelName := SnakeCaseToPascalCase(singular)
|
tablePart := SnakeCaseToPascalCase(singular)
|
||||||
|
|
||||||
if !hasModelPrefix(modelName) {
|
// Include schema name in model name
|
||||||
modelName = "Model" + modelName
|
var modelName string
|
||||||
|
if schemaName != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelName
|
return modelName
|
||||||
@@ -333,13 +351,13 @@ func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
|||||||
|
|
||||||
// generateHasManyFieldName generates a field name for has-many relationships
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
// Uses the foreign key column name + source table name to avoid duplicates
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||||
// For has-many, we need to include the source table name to avoid duplicates
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
// e.g., multiple tables referencing the same column on this table
|
// e.g., multiple tables referencing the same column on this table
|
||||||
if len(constraint.Columns) > 0 {
|
if len(constraint.Columns) > 0 {
|
||||||
columnName := constraint.Columns[0]
|
columnName := constraint.Columns[0]
|
||||||
// Get the model name for the source table (pluralized)
|
// Get the model name for the source table (pluralized)
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
// Remove "Model" prefix if present
|
// Remove "Model" prefix if present
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
@@ -350,7 +368,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to table-based naming
|
// Fallback to table-based naming
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
return "Rel" + Pluralize(sourceModelName)
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
}
|
}
|
||||||
@@ -399,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runGoFmt runs go fmt on the specified file
|
||||||
|
func (w *Writer) runGoFmt(filepath string) {
|
||||||
|
cmd := exec.Command("gofmt", "-w", filepath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
// Don't fail the whole operation if gofmt fails, just warn
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||||
func (w *Writer) shouldUseMultiFile() bool {
|
func (w *Writer) shouldUseMultiFile() bool {
|
||||||
// Check if multi_file is explicitly set in metadata
|
// Check if multi_file is explicitly set in metadata
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
// Verify key elements are present
|
// Verify key elements are present
|
||||||
expectations := []string{
|
expectations := []string{
|
||||||
"package models",
|
"package models",
|
||||||
"type ModelUser struct",
|
"type ModelPublicUser struct",
|
||||||
"bun.BaseModel",
|
"bun.BaseModel",
|
||||||
"table:public.users",
|
"table:public.users",
|
||||||
"alias:users",
|
"alias:users",
|
||||||
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
"resolvespec_common.SqlTime",
|
"resolvespec_common.SqlTime",
|
||||||
"bun:\"id",
|
"bun:\"id",
|
||||||
"bun:\"email",
|
"bun:\"email",
|
||||||
"func (m ModelUser) TableName() string",
|
"func (m ModelPublicUser) TableName() string",
|
||||||
"return \"public.users\"",
|
"return \"public.users\"",
|
||||||
"func (m ModelUser) GetID() int64",
|
"func (m ModelPublicUser) GetID() int64",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range expectations {
|
for _, expected := range expectations {
|
||||||
@@ -191,9 +191,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
|
|
||||||
usersStr := string(usersContent)
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
// Should have RelUserIDPosts (has-many) field
|
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||||
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||||
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,8 +309,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
|||||||
|
|
||||||
// Should have two different has-many relationships with unique names
|
// Should have two different has-many relationships with unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||||
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -455,10 +455,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
|
|
||||||
// Verify all has-many relationships have unique names
|
// Verify all has-many relationships have unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDAPIProviderLogins", // Has many via Login
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
"RelRIDOwner", // Has one via rid_owner
|
"RelRIDOwner", // Has one via rid_owner
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -481,6 +481,74 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `bun:"table_name,`) {
|
||||||
|
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper()
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type ModelData struct {
|
|||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
IDColumnName string // Name of the ID column in database
|
IDColumnName string // Name of the ID column in database
|
||||||
Prefix string // 3-letter prefix
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
@@ -110,13 +111,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
tableName = schema + "." + table.Name
|
tableName = schema + "." + table.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate model name: singularize and convert to PascalCase
|
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||||
singularTable := Singularize(table.Name)
|
singularTable := Singularize(table.Name)
|
||||||
modelName := SnakeCaseToPascalCase(singularTable)
|
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||||
|
|
||||||
// Add "Model" prefix if not already present
|
// Include schema name in model name
|
||||||
if !hasModelPrefix(modelName) {
|
var modelName string
|
||||||
modelName = "Model" + modelName
|
if schema != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schema)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &ModelData{
|
model := &ModelData{
|
||||||
@@ -135,6 +140,7 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Sanitize column name to remove backticks
|
// Sanitize column name to remove backticks
|
||||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -144,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,9 +193,28 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasModelPrefix checks if a name already has "Model" prefix
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
func hasModelPrefix(name string) bool {
|
// and adds an underscore suffix if there's a collision
|
||||||
return len(name) >= 5 && name[:5] == "Model"
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
|
|||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||||
m.{{.PrimaryKeyField}} = int32(newid)
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
return w.writeOutput(formatted)
|
if err := w.writeOutput(formatted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the output file
|
||||||
|
if w.options.OutputPath != "" {
|
||||||
|
w.runGoFmt(w.options.OutputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMultiFile writes each table to a separate file
|
// writeMultiFile writes each table to a separate file
|
||||||
@@ -211,6 +221,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the generated file
|
||||||
|
w.runGoFmt(filepath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +248,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create relationship field (belongs-to)
|
// Create relationship field (belongs-to)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||||
fieldName := w.generateBelongsToFieldName(constraint)
|
fieldName := w.generateBelongsToFieldName(constraint)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||||
@@ -264,8 +277,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
// Check if this constraint references our table
|
// Check if this constraint references our table
|
||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||||
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||||
|
|
||||||
@@ -297,13 +310,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModelName generates the model name from a table name
|
// getModelName generates the model name from schema and table name
|
||||||
func (w *Writer) getModelName(tableName string) string {
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||||
singular := Singularize(tableName)
|
singular := Singularize(tableName)
|
||||||
modelName := SnakeCaseToPascalCase(singular)
|
tablePart := SnakeCaseToPascalCase(singular)
|
||||||
|
|
||||||
if !hasModelPrefix(modelName) {
|
// Include schema name in model name
|
||||||
modelName = "Model" + modelName
|
var modelName string
|
||||||
|
if schemaName != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelName
|
return modelName
|
||||||
@@ -327,13 +345,13 @@ func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) strin
|
|||||||
|
|
||||||
// generateHasManyFieldName generates a field name for has-many relationships
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
// Uses the foreign key column name + source table name to avoid duplicates
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||||
// For has-many, we need to include the source table name to avoid duplicates
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
// e.g., multiple tables referencing the same column on this table
|
// e.g., multiple tables referencing the same column on this table
|
||||||
if len(constraint.Columns) > 0 {
|
if len(constraint.Columns) > 0 {
|
||||||
columnName := constraint.Columns[0]
|
columnName := constraint.Columns[0]
|
||||||
// Get the model name for the source table (pluralized)
|
// Get the model name for the source table (pluralized)
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
// Remove "Model" prefix if present
|
// Remove "Model" prefix if present
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
@@ -344,7 +362,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to table-based naming
|
// Fallback to table-based naming
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
return "Rel" + Pluralize(sourceModelName)
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
}
|
}
|
||||||
@@ -393,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runGoFmt runs go fmt on the specified file
|
||||||
|
func (w *Writer) runGoFmt(filepath string) {
|
||||||
|
cmd := exec.Command("gofmt", "-w", filepath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
// Don't fail the whole operation if gofmt fails, just warn
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||||
func (w *Writer) shouldUseMultiFile() bool {
|
func (w *Writer) shouldUseMultiFile() bool {
|
||||||
// Check if multi_file is explicitly set in metadata
|
// Check if multi_file is explicitly set in metadata
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
// Verify key elements are present
|
// Verify key elements are present
|
||||||
expectations := []string{
|
expectations := []string{
|
||||||
"package models",
|
"package models",
|
||||||
"type ModelUser struct",
|
"type ModelPublicUser struct",
|
||||||
"ID",
|
"ID",
|
||||||
"int64",
|
"int64",
|
||||||
"Email",
|
"Email",
|
||||||
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
"time.Time",
|
"time.Time",
|
||||||
"gorm:\"column:id",
|
"gorm:\"column:id",
|
||||||
"gorm:\"column:email",
|
"gorm:\"column:email",
|
||||||
"func (m ModelUser) TableName() string",
|
"func (m ModelPublicUser) TableName() string",
|
||||||
"return \"public.users\"",
|
"return \"public.users\"",
|
||||||
"func (m ModelUser) GetID() int64",
|
"func (m ModelPublicUser) GetID() int64",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range expectations {
|
for _, expected := range expectations {
|
||||||
@@ -180,9 +180,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
|
|
||||||
usersStr := string(usersContent)
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
// Should have RelUserIDPosts (has-many) field
|
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||||
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||||
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,8 +298,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
|||||||
|
|
||||||
// Should have two different has-many relationships with unique names
|
// Should have two different has-many relationships with unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||||
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -444,10 +444,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
|
|
||||||
// Verify all has-many relationships have unique names
|
// Verify all has-many relationships have unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDAPIProviderLogins", // Has many via Login
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
"RelRIDOwner", // Belongs to via rid_owner
|
"RelRIDOwner", // Belongs to via rid_owner
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -470,6 +470,134 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `gorm:"column:table_name;`) {
|
||||||
|
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||||
|
// Test scenario: tables with different primary key types
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkType string
|
||||||
|
expectedPK string
|
||||||
|
castType string
|
||||||
|
}{
|
||||||
|
{"int32_pk", "int", "int32", "int32(newid)"},
|
||||||
|
{"int16_pk", "smallint", "int16", "int16(newid)"},
|
||||||
|
{"int64_pk", "bigint", "int64", "int64(newid)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table := models.InitTable("test_table", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: tt.pkType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify UpdateID method has correct type cast
|
||||||
|
if !strings.Contains(generated, tt.castType) {
|
||||||
|
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no invalid int32(newid) for non-int32 types
|
||||||
|
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
|
||||||
|
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify UpdateID parameter is int64 (for consistency)
|
||||||
|
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
Reference in New Issue
Block a user