Support typed primary key helpers in gorm and bun writers

This commit is contained in:
2026-05-05 10:32:33 +02:00
parent e828d48798
commit 4303dcf59b
8 changed files with 328 additions and 39 deletions

View File

@@ -2,6 +2,7 @@ package gorm
import (
"sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
@@ -17,17 +18,20 @@ type TemplateData struct {
// ModelData represents a single model/struct in the template
type ModelData struct {
Name string
TableName string // schema.table format
SchemaName string
TableNameOnly string // just table name without schema
Comment string
Fields []*FieldData
Config *MethodConfig
PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
Name string
TableName string // schema.table format
SchemaName string
TableNameOnly string // just table name without schema
Comment string
Fields []*FieldData
Config *MethodConfig
PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
}
// FieldData represents a single field in a struct
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
// Sanitize column name to remove backticks
safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
model.IDColumnName = safeName
break
}
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
return comment
}
func isStringLikePrimaryKeyType(goType string) bool {
switch goType {
case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID":
return true
default:
return false
}
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string {

View File

@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
{{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value
func (m {{.Name}}) GetID() int64 {
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return int64(m.{{.PrimaryKeyField}})
{{- end}}
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
}
{{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) {
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid)
}
{{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) {
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}}

View File

@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
}
}
// Add fmt import if GetIDStr is enabled
if w.config.GenerateGetIDStr {
// Add fmt import when generated helper methods need string formatting.
if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"")
}
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
}
}
// Add fmt import if GetIDStr is enabled
if w.config.GenerateGetIDStr {
// Add fmt import when generated helper methods need string formatting.
if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"")
}
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
}
}
func (w *Writer) needsFmtImport(models []*ModelData) bool {
if w.config.GenerateGetIDStr {
for _, model := range models {
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
if w.config.GenerateUpdateID {
for _, model := range models {
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
return false
}
// findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas {

View File

@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
}
}
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
table := models.InitTable("accounts", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: "uuid",
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
expectations := []string{
"ID string",
"func (m ModelPublicAccounts) GetID() string",
"return m.ID",
"func (m ModelPublicAccounts) GetIDStr() string",
"func (m ModelPublicAccounts) SetID(newid string)",
"func (m *ModelPublicAccounts) UpdateID(newid string)",
"m.ID = newid",
}
for _, expected := range expectations {
if !strings.Contains(generated, expected) {
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
}
}
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
}
}
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
tests := []struct {
input string