Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a447b68b22 | |||
| 4303dcf59b | |||
| e828d48798 | |||
| 6e470a9239 |
@@ -1,6 +1,6 @@
|
|||||||
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
||||||
pkgname=relspec
|
pkgname=relspec
|
||||||
pkgver=1.0.51
|
pkgver=1.0.53
|
||||||
pkgrel=1
|
pkgrel=1
|
||||||
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
||||||
arch=('x86_64' 'aarch64')
|
arch=('x86_64' 'aarch64')
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
Name: relspec
|
Name: relspec
|
||||||
Version: 1.0.51
|
Version: 1.0.53
|
||||||
Release: 1%{?dist}
|
Release: 1%{?dist}
|
||||||
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
||||||
|
|
||||||
|
|||||||
@@ -18,17 +18,20 @@ type TemplateData struct {
|
|||||||
|
|
||||||
// ModelData represents a single model/struct in the template
|
// ModelData represents a single model/struct in the template
|
||||||
type ModelData struct {
|
type ModelData struct {
|
||||||
Name string
|
Name string
|
||||||
TableName string // schema.table format
|
TableName string // schema.table format
|
||||||
SchemaName string
|
SchemaName string
|
||||||
TableNameOnly string // just table name without schema
|
TableNameOnly string // just table name without schema
|
||||||
Comment string
|
Comment string
|
||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
IDColumnName string // Name of the ID column in database
|
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
|
||||||
Prefix string // 3-letter prefix
|
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
|
||||||
|
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
|
||||||
|
IDColumnName string // Name of the ID column in database
|
||||||
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// FieldData represents a single field in a struct
|
// FieldData represents a single field in a struct
|
||||||
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
|
|||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
|
model.PrimaryKeyType = goType
|
||||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
||||||
|
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
|
||||||
|
model.PrimaryKeyIDType = "int64"
|
||||||
|
if model.PrimaryKeyIsStr {
|
||||||
|
model.PrimaryKeyIDType = "string"
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStringLikePrimaryKeyType(goType string) bool {
|
||||||
|
switch goType {
|
||||||
|
case "string", "sql.NullString", "resolvespec_common.SqlString", "resolvespec_common.SqlUUID":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
// and adds an underscore suffix if there's a collision
|
// and adds an underscore suffix if there's a collision
|
||||||
func resolveFieldNameCollision(fieldName string) string {
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
|||||||
@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
|
|||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||||
// GetID returns the primary key value
|
// GetID returns the primary key value
|
||||||
func (m {{.Name}}) GetID() int64 {
|
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
|
||||||
{{if .PrimaryKeyIsSQL -}}
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else -}}
|
||||||
return m.{{.PrimaryKeyField}}.Int64()
|
return m.{{.PrimaryKeyField}}.Int64()
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
return int64(m.{{.PrimaryKeyField}})
|
return int64(m.{{.PrimaryKeyField}})
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||||
// GetIDStr returns the primary key as a string
|
// GetIDStr returns the primary key as a string
|
||||||
func (m {{.Name}}) GetIDStr() string {
|
func (m {{.Name}}) GetIDStr() string {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
|
{{- else -}}
|
||||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||||
// SetID sets the primary key value
|
// SetID sets the primary key value
|
||||||
func (m {{.Name}}) SetID(newid int64) {
|
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
|
||||||
m.UpdateID(newid)
|
m.UpdateID(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
|
||||||
{{if .PrimaryKeyIsSQL -}}
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}}.FromString(newid)
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
m.{{.PrimaryKeyField}} = int32(newid)
|
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}} = newid
|
||||||
|
{{- else -}}
|
||||||
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
|
{{- end}}
|
||||||
{{- end}}
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
|
|||||||
@@ -323,7 +323,7 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
|
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
|
||||||
if isArray {
|
if isArray && tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
parts = append(parts, "array")
|
parts = append(parts, "array")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) needsFmtImport(models []*ModelData) bool {
|
||||||
|
if w.config.GenerateGetIDStr {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.config.GenerateUpdateID {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// findTable finds a table by schema and name in the database
|
// findTable finds a table by schema and name in the database
|
||||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
|
|||||||
@@ -590,6 +590,116 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkType string
|
||||||
|
expectedPK string
|
||||||
|
expectedLine string
|
||||||
|
forbidInt32 bool
|
||||||
|
}{
|
||||||
|
{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
|
||||||
|
{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
|
||||||
|
{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table := models.InitTable("test_table", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: tt.pkType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
if !strings.Contains(generated, tt.expectedLine) {
|
||||||
|
t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(generated, "ID "+tt.expectedPK) {
|
||||||
|
t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
|
||||||
|
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
|
||||||
|
table := models.InitTable("accounts", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "uuid",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
expectations := []string{
|
||||||
|
"resolvespec_common.SqlUUID",
|
||||||
|
"func (m ModelPublicAccounts) GetID() string",
|
||||||
|
"return m.ID.String()",
|
||||||
|
"func (m ModelPublicAccounts) GetIDStr() string",
|
||||||
|
"func (m ModelPublicAccounts) SetID(newid string)",
|
||||||
|
"func (m *ModelPublicAccounts) UpdateID(newid string)",
|
||||||
|
"m.ID.FromString(newid)",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectations {
|
||||||
|
if !strings.Contains(generated, expected) {
|
||||||
|
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeMapper_BuildBunTag(t *testing.T) {
|
func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||||
mapper := NewTypeMapper("")
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
@@ -696,7 +806,7 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
|||||||
Type: "text[]",
|
Type: "text[]",
|
||||||
NotNull: false,
|
NotNull: false,
|
||||||
},
|
},
|
||||||
want: []string{"tags,", "type:text[],", "array,"},
|
want: []string{"tags,", "type:text[],"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "integer array type",
|
name: "integer array type",
|
||||||
@@ -705,7 +815,7 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
|||||||
Type: "integer[]",
|
Type: "integer[]",
|
||||||
NotNull: true,
|
NotNull: true,
|
||||||
},
|
},
|
||||||
want: []string{"scores,", "type:integer[],", "array,"},
|
want: []string{"scores,", "type:integer[],"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,6 +827,30 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
|||||||
t.Errorf("BuildBunTag() = %q, missing %q", result, part)
|
t.Errorf("BuildBunTag() = %q, missing %q", result, part)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// resolvespec mode must NOT add "array" — SqlXxxArray uses sql.Scanner
|
||||||
|
if strings.Contains(result, ",array,") || strings.HasSuffix(result, ",array,") {
|
||||||
|
t.Errorf("BuildBunTag() = %q, must not contain 'array' in resolvespec mode", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeMapper_BuildBunTag_StdlibArrayHasArrayTag(t *testing.T) {
|
||||||
|
mapper := NewTypeMapper(writers.NullableTypeStdlib)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
column *models.Column
|
||||||
|
}{
|
||||||
|
{name: "text array", column: &models.Column{Name: "tags", Type: "text[]"}},
|
||||||
|
{name: "integer array", column: &models.Column{Name: "scores", Type: "integer[]", NotNull: true}},
|
||||||
|
}
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := mapper.BuildBunTag(tt.column, nil)
|
||||||
|
if !strings.Contains(result, "array") {
|
||||||
|
t.Errorf("BuildBunTag() = %q, expected 'array' in stdlib mode", result)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
@@ -17,17 +18,20 @@ type TemplateData struct {
|
|||||||
|
|
||||||
// ModelData represents a single model/struct in the template
|
// ModelData represents a single model/struct in the template
|
||||||
type ModelData struct {
|
type ModelData struct {
|
||||||
Name string
|
Name string
|
||||||
TableName string // schema.table format
|
TableName string // schema.table format
|
||||||
SchemaName string
|
SchemaName string
|
||||||
TableNameOnly string // just table name without schema
|
TableNameOnly string // just table name without schema
|
||||||
Comment string
|
Comment string
|
||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
PrimaryKeyType string // Go type of the primary key field
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
IDColumnName string // Name of the ID column in database
|
PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
|
||||||
Prefix string // 3-letter prefix
|
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
|
||||||
|
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
|
||||||
|
IDColumnName string // Name of the ID column in database
|
||||||
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// FieldData represents a single field in a struct
|
// FieldData represents a single field in a struct
|
||||||
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
|
|||||||
// Sanitize column name to remove backticks
|
// Sanitize column name to remove backticks
|
||||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
|
model.PrimaryKeyType = goType
|
||||||
|
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
|
||||||
|
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
|
||||||
|
model.PrimaryKeyIDType = "int64"
|
||||||
|
if model.PrimaryKeyIsStr {
|
||||||
|
model.PrimaryKeyIDType = "string"
|
||||||
|
}
|
||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStringLikePrimaryKeyType(goType string) bool {
|
||||||
|
switch goType {
|
||||||
|
case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
// and adds an underscore suffix if there's a collision
|
// and adds an underscore suffix if there's a collision
|
||||||
func resolveFieldNameCollision(fieldName string) string {
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
|||||||
@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
|
|||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||||
// GetID returns the primary key value
|
// GetID returns the primary key value
|
||||||
func (m {{.Name}}) GetID() int64 {
|
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.Int64()
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
|
{{- else -}}
|
||||||
return int64(m.{{.PrimaryKeyField}})
|
return int64(m.{{.PrimaryKeyField}})
|
||||||
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||||
// GetIDStr returns the primary key as a string
|
// GetIDStr returns the primary key as a string
|
||||||
func (m {{.Name}}) GetIDStr() string {
|
func (m {{.Name}}) GetIDStr() string {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
|
{{- else -}}
|
||||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||||
// SetID sets the primary key value
|
// SetID sets the primary key value
|
||||||
func (m {{.Name}}) SetID(newid int64) {
|
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
|
||||||
m.UpdateID(newid)
|
m.UpdateID(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}}.FromString(newid)
|
||||||
|
{{- else -}}
|
||||||
|
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}} = newid
|
||||||
|
{{- else -}}
|
||||||
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||||
|
|||||||
@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) needsFmtImport(models []*ModelData) bool {
|
||||||
|
if w.config.GenerateGetIDStr {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.config.GenerateUpdateID {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// findTable finds a table by schema and name in the database
|
// findTable finds a table by schema and name in the database
|
||||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
|
|||||||
@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
|
||||||
|
table := models.InitTable("accounts", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "uuid",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
expectations := []string{
|
||||||
|
"ID string",
|
||||||
|
"func (m ModelPublicAccounts) GetID() string",
|
||||||
|
"return m.ID",
|
||||||
|
"func (m ModelPublicAccounts) GetIDStr() string",
|
||||||
|
"func (m ModelPublicAccounts) SetID(newid string)",
|
||||||
|
"func (m *ModelPublicAccounts) UpdateID(newid string)",
|
||||||
|
"m.ID = newid",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectations {
|
||||||
|
if !strings.Contains(generated, expected) {
|
||||||
|
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
Reference in New Issue
Block a user