Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 53ff745d5d | |||
| 17bc8ed395 | |||
| a447b68b22 | |||
| 4303dcf59b | |||
| e828d48798 | |||
| 6e470a9239 |
@@ -1,6 +1,6 @@
|
||||
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
||||
pkgname=relspec
|
||||
pkgver=1.0.51
|
||||
pkgver=1.0.54
|
||||
pkgrel=1
|
||||
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
||||
arch=('x86_64' 'aarch64')
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
Name: relspec
|
||||
Version: 1.0.51
|
||||
Version: 1.0.54
|
||||
Release: 1%{?dist}
|
||||
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
||||
|
||||
|
||||
@@ -18,17 +18,20 @@ type TemplateData struct {
|
||||
|
||||
// ModelData represents a single model/struct in the template
|
||||
type ModelData struct {
|
||||
Name string
|
||||
TableName string // schema.table format
|
||||
SchemaName string
|
||||
TableNameOnly string // just table name without schema
|
||||
Comment string
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
Name string
|
||||
TableName string // schema.table format
|
||||
SchemaName string
|
||||
TableNameOnly string // just table name without schema
|
||||
Comment string
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
PrimaryKeyType string // Go type of the primary key field
|
||||
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
|
||||
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
|
||||
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
}
|
||||
|
||||
// FieldData represents a single field in a struct
|
||||
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
|
||||
model.IDColumnName = safeName
|
||||
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
model.PrimaryKeyType = goType
|
||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
||||
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
|
||||
model.PrimaryKeyIDType = "int64"
|
||||
if model.PrimaryKeyIsStr {
|
||||
model.PrimaryKeyIDType = "string"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
|
||||
return comment
|
||||
}
|
||||
|
||||
func isStringLikePrimaryKeyType(goType string) bool {
|
||||
switch goType {
|
||||
case "string", "sql.NullString", "resolvespec_common.SqlString", "resolvespec_common.SqlUUID":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||
// and adds an underscore suffix if there's a collision
|
||||
func resolveFieldNameCollision(fieldName string) string {
|
||||
|
||||
@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||
// GetID returns the primary key value
|
||||
func (m {{.Name}}) GetID() int64 {
|
||||
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
|
||||
{{if .PrimaryKeyIsSQL -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
return m.{{.PrimaryKeyField}}.String()
|
||||
{{- else -}}
|
||||
return m.{{.PrimaryKeyField}}.Int64()
|
||||
{{- end}}
|
||||
{{- else -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
return m.{{.PrimaryKeyField}}
|
||||
{{- else -}}
|
||||
return int64(m.{{.PrimaryKeyField}})
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||
// GetIDStr returns the primary key as a string
|
||||
func (m {{.Name}}) GetIDStr() string {
|
||||
{{if .PrimaryKeyIsSQL -}}
|
||||
return m.{{.PrimaryKeyField}}.String()
|
||||
{{- else if .PrimaryKeyIsStr -}}
|
||||
return m.{{.PrimaryKeyField}}
|
||||
{{- else -}}
|
||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||
{{- end}}
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||
// SetID sets the primary key value
|
||||
func (m {{.Name}}) SetID(newid int64) {
|
||||
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
|
||||
m.UpdateID(newid)
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||
// UpdateID updates the primary key value
|
||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
|
||||
{{if .PrimaryKeyIsSQL -}}
|
||||
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
m.{{.PrimaryKeyField}}.FromString(newid)
|
||||
{{- else -}}
|
||||
m.{{.PrimaryKeyField}} = int32(newid)
|
||||
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||
{{- end}}
|
||||
{{- else -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
m.{{.PrimaryKeyField}} = newid
|
||||
{{- else -}}
|
||||
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
}
|
||||
{{end}}
|
||||
|
||||
@@ -323,7 +323,7 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
}
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
|
||||
if isArray {
|
||||
if isArray && tm.typeStyle == writers.NullableTypeStdlib {
|
||||
parts = append(parts, "array")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
// Add fmt import when generated helper methods need string formatting.
|
||||
if w.needsFmtImport(templateData.Models) {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
// Add fmt import when generated helper methods need string formatting.
|
||||
if w.needsFmtImport(templateData.Models) {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) needsFmtImport(models []*ModelData) bool {
|
||||
if w.config.GenerateGetIDStr {
|
||||
for _, model := range models {
|
||||
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if w.config.GenerateUpdateID {
|
||||
for _, model := range models {
|
||||
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// findTable finds a table by schema and name in the database
|
||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||
for _, schema := range db.Schemas {
|
||||
|
||||
@@ -590,6 +590,116 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pkType string
|
||||
expectedPK string
|
||||
expectedLine string
|
||||
forbidInt32 bool
|
||||
}{
|
||||
{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
|
||||
{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
|
||||
{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
table := models.InitTable("test_table", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: tt.pkType,
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
if !strings.Contains(generated, tt.expectedLine) {
|
||||
t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
|
||||
}
|
||||
|
||||
if !strings.Contains(generated, "ID "+tt.expectedPK) {
|
||||
t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
|
||||
}
|
||||
|
||||
if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
|
||||
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||
}
|
||||
|
||||
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
|
||||
table := models.InitTable("accounts", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: "uuid",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
expectations := []string{
|
||||
"resolvespec_common.SqlUUID",
|
||||
"func (m ModelPublicAccounts) GetID() string",
|
||||
"return m.ID.String()",
|
||||
"func (m ModelPublicAccounts) GetIDStr() string",
|
||||
"func (m ModelPublicAccounts) SetID(newid string)",
|
||||
"func (m *ModelPublicAccounts) UpdateID(newid string)",
|
||||
"m.ID.FromString(newid)",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
if !strings.Contains(generated, expected) {
|
||||
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
|
||||
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
mapper := NewTypeMapper("")
|
||||
|
||||
@@ -696,7 +806,7 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
Type: "text[]",
|
||||
NotNull: false,
|
||||
},
|
||||
want: []string{"tags,", "type:text[],", "array,"},
|
||||
want: []string{"tags,", "type:text[],"},
|
||||
},
|
||||
{
|
||||
name: "integer array type",
|
||||
@@ -705,7 +815,7 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
Type: "integer[]",
|
||||
NotNull: true,
|
||||
},
|
||||
want: []string{"scores,", "type:integer[],", "array,"},
|
||||
want: []string{"scores,", "type:integer[],"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -717,6 +827,30 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
t.Errorf("BuildBunTag() = %q, missing %q", result, part)
|
||||
}
|
||||
}
|
||||
// resolvespec mode must NOT add "array" — SqlXxxArray uses sql.Scanner
|
||||
if strings.Contains(result, ",array,") || strings.HasSuffix(result, ",array,") {
|
||||
t.Errorf("BuildBunTag() = %q, must not contain 'array' in resolvespec mode", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_BuildBunTag_StdlibArrayHasArrayTag(t *testing.T) {
|
||||
mapper := NewTypeMapper(writers.NullableTypeStdlib)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
column *models.Column
|
||||
}{
|
||||
{name: "text array", column: &models.Column{Name: "tags", Type: "text[]"}},
|
||||
{name: "integer array", column: &models.Column{Name: "scores", Type: "integer[]", NotNull: true}},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := mapper.BuildBunTag(tt.column, nil)
|
||||
if !strings.Contains(result, "array") {
|
||||
t.Errorf("BuildBunTag() = %q, expected 'array' in stdlib mode", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package gorm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
@@ -17,17 +18,20 @@ type TemplateData struct {
|
||||
|
||||
// ModelData represents a single model/struct in the template
|
||||
type ModelData struct {
|
||||
Name string
|
||||
TableName string // schema.table format
|
||||
SchemaName string
|
||||
TableNameOnly string // just table name without schema
|
||||
Comment string
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
PrimaryKeyType string // Go type of the primary key field
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
Name string
|
||||
TableName string // schema.table format
|
||||
SchemaName string
|
||||
TableNameOnly string // just table name without schema
|
||||
Comment string
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
PrimaryKeyType string // Go type of the primary key field
|
||||
PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
|
||||
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
|
||||
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
}
|
||||
|
||||
// FieldData represents a single field in a struct
|
||||
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
|
||||
// Sanitize column name to remove backticks
|
||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
model.PrimaryKeyType = goType
|
||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
|
||||
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
|
||||
model.PrimaryKeyIDType = "int64"
|
||||
if model.PrimaryKeyIsStr {
|
||||
model.PrimaryKeyIDType = "string"
|
||||
}
|
||||
model.IDColumnName = safeName
|
||||
break
|
||||
}
|
||||
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
|
||||
return comment
|
||||
}
|
||||
|
||||
func isStringLikePrimaryKeyType(goType string) bool {
|
||||
switch goType {
|
||||
case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||
// and adds an underscore suffix if there's a collision
|
||||
func resolveFieldNameCollision(fieldName string) string {
|
||||
|
||||
@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||
// GetID returns the primary key value
|
||||
func (m {{.Name}}) GetID() int64 {
|
||||
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
|
||||
{{if .PrimaryKeyIsSQL -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
return m.{{.PrimaryKeyField}}.String()
|
||||
{{- else -}}
|
||||
return m.{{.PrimaryKeyField}}.Int64()
|
||||
{{- end}}
|
||||
{{- else -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
return m.{{.PrimaryKeyField}}
|
||||
{{- else -}}
|
||||
return int64(m.{{.PrimaryKeyField}})
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||
// GetIDStr returns the primary key as a string
|
||||
func (m {{.Name}}) GetIDStr() string {
|
||||
{{if .PrimaryKeyIsSQL -}}
|
||||
return m.{{.PrimaryKeyField}}.String()
|
||||
{{- else if .PrimaryKeyIsStr -}}
|
||||
return m.{{.PrimaryKeyField}}
|
||||
{{- else -}}
|
||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||
{{- end}}
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||
// SetID sets the primary key value
|
||||
func (m {{.Name}}) SetID(newid int64) {
|
||||
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
|
||||
m.UpdateID(newid)
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||
// UpdateID updates the primary key value
|
||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
|
||||
{{if .PrimaryKeyIsSQL -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
m.{{.PrimaryKeyField}}.FromString(newid)
|
||||
{{- else -}}
|
||||
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||
{{- end}}
|
||||
{{- else -}}
|
||||
{{if .PrimaryKeyIsStr -}}
|
||||
m.{{.PrimaryKeyField}} = newid
|
||||
{{- else -}}
|
||||
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||
|
||||
@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
// Add fmt import when generated helper methods need string formatting.
|
||||
if w.needsFmtImport(templateData.Models) {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
// Add fmt import when generated helper methods need string formatting.
|
||||
if w.needsFmtImport(templateData.Models) {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) needsFmtImport(models []*ModelData) bool {
|
||||
if w.config.GenerateGetIDStr {
|
||||
for _, model := range models {
|
||||
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if w.config.GenerateUpdateID {
|
||||
for _, model := range models {
|
||||
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// findTable finds a table by schema and name in the database
|
||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||
for _, schema := range db.Schemas {
|
||||
|
||||
@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
|
||||
table := models.InitTable("accounts", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: "uuid",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
expectations := []string{
|
||||
"ID string",
|
||||
"func (m ModelPublicAccounts) GetID() string",
|
||||
"return m.ID",
|
||||
"func (m ModelPublicAccounts) GetIDStr() string",
|
||||
"func (m ModelPublicAccounts) SetID(newid string)",
|
||||
"func (m *ModelPublicAccounts) UpdateID(newid string)",
|
||||
"m.ID = newid",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
if !strings.Contains(generated, expected) {
|
||||
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
|
||||
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
@@ -31,6 +31,10 @@ type MigrationWriter struct {
|
||||
|
||||
// NewMigrationWriter creates a new templated migration writer
|
||||
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
||||
if options == nil {
|
||||
options = &writers.WriterOptions{}
|
||||
}
|
||||
|
||||
executor, err := NewTemplateExecutor(options.FlattenSchema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
||||
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
|
||||
|
||||
// WriteMigration generates migration scripts using templates
|
||||
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
|
||||
if model == nil {
|
||||
return fmt.Errorf("model database is required")
|
||||
}
|
||||
if w.options == nil {
|
||||
w.options = &writers.WriterOptions{}
|
||||
}
|
||||
if current == nil {
|
||||
current = models.InitDatabase(model.Name)
|
||||
}
|
||||
|
||||
var writer io.Writer
|
||||
var file *os.File
|
||||
var err error
|
||||
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
|
||||
|
||||
// Process each schema in the model
|
||||
for _, modelSchema := range model.Schemas {
|
||||
if modelSchema == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find corresponding schema in current database
|
||||
var currentSchema *models.Schema
|
||||
for _, cs := range current.Schemas {
|
||||
if cs == nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(cs.Name, modelSchema.Name) {
|
||||
currentSchema = cs
|
||||
break
|
||||
@@ -545,12 +566,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
||||
indexType = modelIndex.Type
|
||||
}
|
||||
|
||||
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
|
||||
if len(columnExprs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
|
||||
SchemaName: model.Name,
|
||||
TableName: modelTable.Name,
|
||||
IndexName: indexName,
|
||||
IndexType: indexType,
|
||||
Columns: strings.Join(modelIndex.Columns, ", "),
|
||||
Columns: strings.Join(columnExprs, ", "),
|
||||
Unique: modelIndex.Unique,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -573,6 +599,27 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
||||
return scripts, nil
|
||||
}
|
||||
|
||||
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
|
||||
columnExprs := make([]string, 0, len(index.Columns))
|
||||
for _, colName := range index.Columns {
|
||||
colExpr := colName
|
||||
if table != nil {
|
||||
if col, ok := table.Columns[colName]; ok && col != nil {
|
||||
colExpr = col.SQLName()
|
||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
||||
opClass := extractOperatorClass(index.Comment)
|
||||
if opClass == "" {
|
||||
opClass = "gin_trgm_ops"
|
||||
}
|
||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||
}
|
||||
}
|
||||
}
|
||||
columnExprs = append(columnExprs, colExpr)
|
||||
}
|
||||
return columnExprs
|
||||
}
|
||||
|
||||
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
|
||||
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
||||
scripts := make([]MigrationScript, 0)
|
||||
|
||||
@@ -97,6 +97,89 @@ func TestWriteMigration_ArrayDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("articles", "public")
|
||||
titleCol := models.InitColumn("title", "articles", "public")
|
||||
titleCol.Type = "text"
|
||||
table.Columns["title"] = titleCol
|
||||
|
||||
index := &models.Index{
|
||||
Name: "idx_articles_title_gin",
|
||||
Type: "gin",
|
||||
Columns: []string{"title"},
|
||||
}
|
||||
table.Indexes[index.Name] = index
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
|
||||
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("plans", "public")
|
||||
tagsCol := models.InitColumn("tags", "plans", "public")
|
||||
tagsCol.Type = "text[]"
|
||||
table.Columns["tags"] = tagsCol
|
||||
|
||||
index := &models.Index{
|
||||
Name: "idx_plans_tags",
|
||||
Type: "gin",
|
||||
Columns: []string{"tags"},
|
||||
}
|
||||
table.Indexes[index.Name] = index
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "USING gin (tags)") {
|
||||
t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "gin_trgm_ops") {
|
||||
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_WithAudit(t *testing.T) {
|
||||
// Current database (empty)
|
||||
current := models.InitDatabase("testdb")
|
||||
@@ -322,3 +405,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
||||
t.Error("Migration missing FOREIGN KEY")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMigrationWriter_NilOptions(t *testing.T) {
|
||||
writer, err := NewMigrationWriter(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
|
||||
}
|
||||
if writer == nil {
|
||||
t.Fatal("expected writer instance")
|
||||
}
|
||||
if writer.options == nil {
|
||||
t.Fatal("expected default writer options to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("users", "public")
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "integer"
|
||||
idCol.NotNull = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, nil); err != nil {
|
||||
t.Fatalf("WriteMigration with nil current failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "CREATE TABLE") {
|
||||
t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,6 +267,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
|
||||
ConstraintName string
|
||||
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||
Columns string
|
||||
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
|
||||
}
|
||||
|
||||
// Execute methods for each template
|
||||
|
||||
@@ -1,26 +1,42 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
auto_pk_name text;
|
||||
current_pk_name text;
|
||||
current_pk_matches boolean := false;
|
||||
BEGIN
|
||||
-- Drop auto-generated primary key if it exists
|
||||
SELECT constraint_name INTO auto_pk_name
|
||||
FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_type = 'PRIMARY KEY'
|
||||
AND constraint_name IN ({{.AutoGenNames}});
|
||||
SELECT tc.constraint_name,
|
||||
COALESCE(
|
||||
ARRAY(
|
||||
SELECT a.attname
|
||||
FROM pg_constraint c
|
||||
JOIN pg_class t ON t.oid = c.conrelid
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
|
||||
ON TRUE
|
||||
JOIN pg_attribute a
|
||||
ON a.attrelid = t.oid
|
||||
AND a.attnum = cols.attnum
|
||||
WHERE c.contype = 'p'
|
||||
AND n.nspname = '{{.SchemaName}}'
|
||||
AND t.relname = '{{.TableName}}'
|
||||
ORDER BY cols.ord
|
||||
),
|
||||
ARRAY[]::text[]
|
||||
) = ARRAY[{{.ColumnNames}}]
|
||||
INTO current_pk_name, current_pk_matches
|
||||
FROM information_schema.table_constraints tc
|
||||
WHERE tc.table_schema = '{{.SchemaName}}'
|
||||
AND tc.table_name = '{{.TableName}}'
|
||||
AND tc.constraint_type = 'PRIMARY KEY';
|
||||
|
||||
IF auto_pk_name IS NOT NULL THEN
|
||||
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||
IF current_pk_name IS NOT NULL
|
||||
AND NOT current_pk_matches
|
||||
AND current_pk_name IN ({{.AutoGenNames}}) THEN
|
||||
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
|
||||
END IF;
|
||||
|
||||
-- Add named primary key if it doesn't exist
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
-- Add the desired primary key only when no matching primary key already exists.
|
||||
IF current_pk_name IS NULL
|
||||
OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
|
||||
@@ -228,6 +228,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
ConstraintName: pkName,
|
||||
AutoGenNames: formatStringList(autoGenPKNames),
|
||||
Columns: strings.Join(pkColumns, ", "),
|
||||
ColumnNames: formatStringList(pkColumns),
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||
@@ -806,6 +807,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||
ConstraintName: pkName,
|
||||
AutoGenNames: formatStringList(autoGenPKNames),
|
||||
Columns: strings.Join(columnNames, ", "),
|
||||
ColumnNames: formatStringList(columnNames),
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||
|
||||
@@ -673,9 +673,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
||||
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
||||
}
|
||||
|
||||
// Verify it checks for our specific named constraint before adding it
|
||||
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
||||
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
||||
// Verify it compares the current primary key columns before dropping/recreating
|
||||
if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
|
||||
t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
|
||||
}
|
||||
|
||||
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
|
||||
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
|
||||
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -766,6 +771,43 @@ func TestColumnSizeSpecifiers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
table := models.InitTable("learnings", "public")
|
||||
|
||||
idCol := models.InitColumn("id", "learnings", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
|
||||
parentCol.Type = "bigint"
|
||||
table.Columns["duplicate_of_learning_id"] = parentCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteDatabase(db); err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "current_pk_matches") {
|
||||
t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ARRAY['id']") {
|
||||
t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "NOT current_pk_matches") {
|
||||
t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user