diff --git a/pkg/writers/bun/template_data.go b/pkg/writers/bun/template_data.go index 2f017c7..1b630f0 100644 --- a/pkg/writers/bun/template_data.go +++ b/pkg/writers/bun/template_data.go @@ -18,17 +18,20 @@ type TemplateData struct { // ModelData represents a single model/struct in the template type ModelData struct { - Name string - TableName string // schema.table format - SchemaName string - TableNameOnly string // just table name without schema - Comment string - Fields []*FieldData - Config *MethodConfig - PrimaryKeyField string // Name of the primary key field - PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call) - IDColumnName string // Name of the ID column in database - Prefix string // 3-letter prefix + Name string + TableName string // schema.table format + SchemaName string + TableNameOnly string // just table name without schema + Comment string + Fields []*FieldData + Config *MethodConfig + PrimaryKeyField string // Name of the primary key field + PrimaryKeyType string // Go type of the primary key field + PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call) + PrimaryKeyIsStr bool // Whether helper methods should use string IDs + PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type + IDColumnName string // Name of the ID column in database + Prefix string // 3-letter prefix } // FieldData represents a single field in a struct @@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl model.IDColumnName = safeName // Check if PK type is a SQL type (contains resolvespec_common or sql_types) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) + model.PrimaryKeyType = goType model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types") + model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType) + model.PrimaryKeyIDType = "int64" + if model.PrimaryKeyIsStr { + model.PrimaryKeyIDType = "string" + } break } } @@ -192,6 +201,15 @@ func formatComment(description, comment string) string { return comment } +func isStringLikePrimaryKeyType(goType string) bool { + switch goType { + case "string", "sql.NullString", "resolvespec_common.SqlString", "resolvespec_common.SqlUUID": + return true + default: + return false + } +} + // resolveFieldNameCollision checks if a field name conflicts with generated method names // and adds an underscore suffix if there's a collision func resolveFieldNameCollision(fieldName string) string { diff --git a/pkg/writers/bun/templates.go b/pkg/writers/bun/templates.go index 9b9331b..5c0fa19 100644 --- a/pkg/writers/bun/templates.go +++ b/pkg/writers/bun/templates.go @@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string { {{end}} {{if and .Config.GenerateGetID .PrimaryKeyField}} // GetID returns the primary key value -func (m {{.Name}}) GetID() int64 { +func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} { {{if .PrimaryKeyIsSQL -}} + {{if .PrimaryKeyIsStr -}} + return m.{{.PrimaryKeyField}}.String() + {{- else -}} return m.{{.PrimaryKeyField}}.Int64() + {{- end}} + {{- else -}} + {{if .PrimaryKeyIsStr -}} + return m.{{.PrimaryKeyField}} {{- else -}} return int64(m.{{.PrimaryKeyField}}) {{- end}} + {{- end}} } {{end}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}} // GetIDStr returns the primary key as a string func (m {{.Name}}) GetIDStr() string { + {{if .PrimaryKeyIsSQL -}} + return m.{{.PrimaryKeyField}}.String() + {{- else if .PrimaryKeyIsStr -}} + return m.{{.PrimaryKeyField}} + {{- else -}} return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) + {{- end}} } {{end}} {{if and .Config.GenerateSetID .PrimaryKeyField}} // SetID sets the primary key value -func (m {{.Name}}) SetID(newid int64) { +func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) { m.UpdateID(newid) } {{end}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}} // UpdateID updates the primary key value -func (m *{{.Name}}) UpdateID(newid int64) { +func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) { {{if .PrimaryKeyIsSQL -}} - m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid)) + {{if .PrimaryKeyIsStr -}} + m.{{.PrimaryKeyField}}.FromString(newid) {{- else -}} - m.{{.PrimaryKeyField}} = int32(newid) + m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid)) + {{- end}} + {{- else -}} + {{if .PrimaryKeyIsStr -}} + m.{{.PrimaryKeyField}} = newid + {{- else -}} + m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid) + {{- end}} {{- end}} } {{end}} diff --git a/pkg/writers/bun/writer.go b/pkg/writers/bun/writer.go index 3e8afd7..37561ac 100644 --- a/pkg/writers/bun/writer.go +++ b/pkg/writers/bun/writer.go @@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error { } } - // Add fmt import if GetIDStr is enabled - if w.config.GenerateGetIDStr { + // Add fmt import when generated helper methods need string formatting. + if w.needsFmtImport(templateData.Models) { templateData.AddImport("\"fmt\"") } @@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error { } } - // Add fmt import if GetIDStr is enabled - if w.config.GenerateGetIDStr { + // Add fmt import when generated helper methods need string formatting. + if w.needsFmtImport(templateData.Models) { templateData.AddImport("\"fmt\"") } @@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table } } +func (w *Writer) needsFmtImport(models []*ModelData) bool { + if w.config.GenerateGetIDStr { + for _, model := range models { + if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr { + return true + } + } + } + + if w.config.GenerateUpdateID { + for _, model := range models { + if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr { + return true + } + } + } + + return false +} + // findTable finds a table by schema and name in the database func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { for _, schema := range db.Schemas { diff --git a/pkg/writers/bun/writer_test.go b/pkg/writers/bun/writer_test.go index c53a25e..e987ac1 100644 --- a/pkg/writers/bun/writer_test.go +++ b/pkg/writers/bun/writer_test.go @@ -590,6 +590,116 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) { } } +func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) { + tests := []struct { + name string + pkType string + expectedPK string + expectedLine string + forbidInt32 bool + }{ + {"int32_pk", "int", "int32", "m.ID = int32(newid)", false}, + {"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true}, + {"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + table := models.InitTable("test_table", "public") + table.Columns["id"] = &models.Column{ + Name: "id", + Type: tt.pkType, + NotNull: true, + IsPrimaryKey: true, + } + + tmpDir := t.TempDir() + opts := &writers.WriterOptions{ + PackageName: "models", + OutputPath: filepath.Join(tmpDir, "test.go"), + } + + writer := NewWriter(opts) + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + content, err := os.ReadFile(opts.OutputPath) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + generated := string(content) + + if !strings.Contains(generated, tt.expectedLine) { + t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated) + } + + if !strings.Contains(generated, "ID "+tt.expectedPK) { + t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated) + } + + if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") { + t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated) + } + + if !strings.Contains(generated, "UpdateID(newid int64)") { + t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated) + } + }) + } +} + +func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) { + table := models.InitTable("accounts", "public") + table.Columns["id"] = &models.Column{ + Name: "id", + Type: "uuid", + NotNull: true, + IsPrimaryKey: true, + } + + tmpDir := t.TempDir() + opts := &writers.WriterOptions{ + PackageName: "models", + OutputPath: filepath.Join(tmpDir, "test.go"), + } + + writer := NewWriter(opts) + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + content, err := os.ReadFile(opts.OutputPath) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + generated := string(content) + + expectations := []string{ + "resolvespec_common.SqlUUID", + "func (m ModelPublicAccounts) GetID() string", + "return m.ID.String()", + "func (m ModelPublicAccounts) GetIDStr() string", + "func (m ModelPublicAccounts) SetID(newid string)", + "func (m *ModelPublicAccounts) UpdateID(newid string)", + "m.ID.FromString(newid)", + } + + for _, expected := range expectations { + if !strings.Contains(generated, expected) { + t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated) + } + } + + if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") { + t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated) + } +} + func TestTypeMapper_BuildBunTag(t *testing.T) { mapper := NewTypeMapper("") diff --git a/pkg/writers/gorm/template_data.go b/pkg/writers/gorm/template_data.go index ae4f1c4..2b54350 100644 --- a/pkg/writers/gorm/template_data.go +++ b/pkg/writers/gorm/template_data.go @@ -2,6 +2,7 @@ package gorm import ( "sort" + "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/writers" @@ -17,17 +18,20 @@ type TemplateData struct { // ModelData represents a single model/struct in the template type ModelData struct { - Name string - TableName string // schema.table format - SchemaName string - TableNameOnly string // just table name without schema - Comment string - Fields []*FieldData - Config *MethodConfig - PrimaryKeyField string // Name of the primary key field - PrimaryKeyType string // Go type of the primary key field - IDColumnName string // Name of the ID column in database - Prefix string // 3-letter prefix + Name string + TableName string // schema.table format + SchemaName string + TableNameOnly string // just table name without schema + Comment string + Fields []*FieldData + Config *MethodConfig + PrimaryKeyField string // Name of the primary key field + PrimaryKeyType string // Go type of the primary key field + PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type + PrimaryKeyIsStr bool // Whether helper methods should use string IDs + PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type + IDColumnName string // Name of the ID column in database + Prefix string // 3-letter prefix } // FieldData represents a single field in a struct @@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl // Sanitize column name to remove backticks safeName := writers.SanitizeStructTagValue(col.Name) model.PrimaryKeyField = SnakeCaseToPascalCase(safeName) - model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull) + goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) + model.PrimaryKeyType = goType + model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.") + model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType) + model.PrimaryKeyIDType = "int64" + if model.PrimaryKeyIsStr { + model.PrimaryKeyIDType = "string" + } model.IDColumnName = safeName break } @@ -189,6 +200,15 @@ func formatComment(description, comment string) string { return comment } +func isStringLikePrimaryKeyType(goType string) bool { + switch goType { + case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID": + return true + default: + return false + } +} + // resolveFieldNameCollision checks if a field name conflicts with generated method names // and adds an underscore suffix if there's a collision func resolveFieldNameCollision(fieldName string) string { diff --git a/pkg/writers/gorm/templates.go b/pkg/writers/gorm/templates.go index 5ff0cd5..b337d40 100644 --- a/pkg/writers/gorm/templates.go +++ b/pkg/writers/gorm/templates.go @@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string { {{end}} {{if and .Config.GenerateGetID .PrimaryKeyField}} // GetID returns the primary key value -func (m {{.Name}}) GetID() int64 { +func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} { + {{if .PrimaryKeyIsSQL -}} + {{if .PrimaryKeyIsStr -}} + return m.{{.PrimaryKeyField}}.String() + {{- else -}} + return m.{{.PrimaryKeyField}}.Int64() + {{- end}} + {{- else -}} + {{if .PrimaryKeyIsStr -}} + return m.{{.PrimaryKeyField}} + {{- else -}} return int64(m.{{.PrimaryKeyField}}) + {{- end}} + {{- end}} } {{end}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}} // GetIDStr returns the primary key as a string func (m {{.Name}}) GetIDStr() string { + {{if .PrimaryKeyIsSQL -}} + return m.{{.PrimaryKeyField}}.String() + {{- else if .PrimaryKeyIsStr -}} + return m.{{.PrimaryKeyField}} + {{- else -}} return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) + {{- end}} } {{end}} {{if and .Config.GenerateSetID .PrimaryKeyField}} // SetID sets the primary key value -func (m {{.Name}}) SetID(newid int64) { +func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) { m.UpdateID(newid) } {{end}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}} // UpdateID updates the primary key value -func (m *{{.Name}}) UpdateID(newid int64) { +func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) { + {{if .PrimaryKeyIsSQL -}} + {{if .PrimaryKeyIsStr -}} + m.{{.PrimaryKeyField}}.FromString(newid) + {{- else -}} + m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid)) + {{- end}} + {{- else -}} + {{if .PrimaryKeyIsStr -}} + m.{{.PrimaryKeyField}} = newid + {{- else -}} m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid) + {{- end}} + {{- end}} } {{end}} {{if and .Config.GenerateGetIDName .IDColumnName}} diff --git a/pkg/writers/gorm/writer.go b/pkg/writers/gorm/writer.go index f4a80fe..6e9121e 100644 --- a/pkg/writers/gorm/writer.go +++ b/pkg/writers/gorm/writer.go @@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error { } } - // Add fmt import if GetIDStr is enabled - if w.config.GenerateGetIDStr { + // Add fmt import when generated helper methods need string formatting. + if w.needsFmtImport(templateData.Models) { templateData.AddImport("\"fmt\"") } @@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error { } } - // Add fmt import if GetIDStr is enabled - if w.config.GenerateGetIDStr { + // Add fmt import when generated helper methods need string formatting. + if w.needsFmtImport(templateData.Models) { templateData.AddImport("\"fmt\"") } @@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table } } +func (w *Writer) needsFmtImport(models []*ModelData) bool { + if w.config.GenerateGetIDStr { + for _, model := range models { + if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr { + return true + } + } + } + + if w.config.GenerateUpdateID { + for _, model := range models { + if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr { + return true + } + } + } + + return false +} + // findTable finds a table by schema and name in the database func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { for _, schema := range db.Schemas { diff --git a/pkg/writers/gorm/writer_test.go b/pkg/writers/gorm/writer_test.go index f743d04..49851bf 100644 --- a/pkg/writers/gorm/writer_test.go +++ b/pkg/writers/gorm/writer_test.go @@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) { } } +func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) { + table := models.InitTable("accounts", "public") + table.Columns["id"] = &models.Column{ + Name: "id", + Type: "uuid", + NotNull: true, + IsPrimaryKey: true, + } + + tmpDir := t.TempDir() + opts := &writers.WriterOptions{ + PackageName: "models", + OutputPath: filepath.Join(tmpDir, "test.go"), + } + + writer := NewWriter(opts) + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + content, err := os.ReadFile(opts.OutputPath) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + generated := string(content) + + expectations := []string{ + "ID string", + "func (m ModelPublicAccounts) GetID() string", + "return m.ID", + "func (m ModelPublicAccounts) GetIDStr() string", + "func (m ModelPublicAccounts) SetID(newid string)", + "func (m *ModelPublicAccounts) UpdateID(newid string)", + "m.ID = newid", + } + + for _, expected := range expectations { + if !strings.Contains(generated, expected) { + t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated) + } + } + + if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") { + t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated) + } +} + func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) { tests := []struct { input string