Compare commits

...

8 Commits

Author SHA1 Message Date
30ef1db010 chore(release): update package version to 1.0.56
All checks were successful
Release / test (push) Successful in -32m19s
Release / release (push) Successful in -31m39s
Release / pkg-deb (push) Successful in -31m57s
Release / pkg-aur (push) Successful in -31m46s
Release / pkg-rpm (push) Successful in -4m7s
2026-05-05 14:51:10 +02:00
2d97a47ee1 feat: Enhance PostgreSQL type handling and migration scripts
- Introduced equivalent base types and variants for PostgreSQL types to normalize type comparisons.
- Added functions for normalizing SQL types and retrieving equivalent type variants.
- Updated migration writer to handle type alterations with checks for existing types.
- Implemented logic to create necessary extensions (e.g., pg_trgm) based on schema requirements.
- Enhanced tests to cover new functionality for type normalization and migration handling.
- Improved handling of GIN indexes to use appropriate operator classes based on column types.
2026-05-05 14:50:34 +02:00
72200ea72e chore(release): update package version to 1.0.55
All checks were successful
Release / test (push) Successful in -32m1s
Release / release (push) Successful in -31m13s
Release / pkg-aur (push) Successful in -32m13s
Release / pkg-deb (push) Successful in -31m12s
Release / pkg-rpm (push) Successful in -29m45s
2026-05-05 11:36:29 +02:00
608893a3d6 feat(index): implement GIN index support for quoted text columns and enhance index column resolution 2026-05-05 11:32:15 +02:00
53ff745d5d chore(release): update package version to 1.0.54
All checks were successful
Release / test (push) Successful in -31m47s
Release / release (push) Successful in -31m9s
Release / pkg-aur (push) Successful in -31m57s
Release / pkg-deb (push) Successful in -31m1s
Release / pkg-rpm (push) Successful in -29m27s
2026-05-05 11:12:49 +02:00
17bc8ed395 feat(migration): enhance primary key handling and add GIN index support in migration writer 2026-05-05 11:12:23 +02:00
a447b68b22 chore(release): update package version to 1.0.53
All checks were successful
Release / test (push) Successful in -31m55s
Release / release (push) Successful in -31m19s
Release / pkg-aur (push) Successful in -32m3s
Release / pkg-deb (push) Successful in -31m21s
Release / pkg-rpm (push) Successful in -28m4s
2026-05-05 10:48:27 +02:00
4303dcf59b Support typed primary key helpers in gorm and bun writers 2026-05-05 10:32:33 +02:00
25 changed files with 1720 additions and 119 deletions

View File

@@ -258,6 +258,11 @@ func runMerge(cmd *cobra.Command, args []string) error {
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n") fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result)) fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
if strings.EqualFold(mergeOutputType, "pgsql") && len(result.TypeConflicts) > 0 {
return fmt.Errorf("merge detected conflicting existing column types and cannot safely continue with pgsql output\n%s",
merge.GetColumnTypeConflictSummary(result, 10))
}
// Step 4: Write output // Step 4: Write output
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n") fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType) fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
) )
@@ -160,3 +161,38 @@ func TestRunMerge_FromListMissingSourceType(t *testing.T) {
t.Error("expected error when neither --source-path nor --from-list is provided") t.Error("expected error when neither --source-path nor --from-list is provided")
} }
} }
// TestRunMerge_PgsqlOutputRejectsColumnTypeConflict verifies that a merge
// targeting pgsql output fails fast when the same column exists in both
// schemas with incompatible types (integer vs uuid here), instead of
// emitting unsafe DDL.
func TestRunMerge_PgsqlOutputRejectsColumnTypeConflict(t *testing.T) {
	// Snapshot the package-level merge flags and restore them afterwards so
	// this test does not leak state into sibling tests.
	saved := saveMergeState()
	defer restoreMergeState(saved)
	dir := t.TempDir()
	targetFile := filepath.Join(dir, "target.json")
	sourceFile := filepath.Join(dir, "source.json")
	// Same table and column on both sides, but with conflicting column types.
	writeTestJSONWithSingleColumnType(t, targetFile, "users", "integer")
	writeTestJSONWithSingleColumnType(t, sourceFile, "users", "uuid")
	mergeTargetType = "json"
	mergeTargetPath = targetFile
	mergeTargetConn = ""
	mergeSourceType = "json"
	mergeSourcePath = sourceFile
	mergeSourceConn = ""
	mergeFromList = nil
	mergeOutputType = "pgsql"
	mergeOutputPath = ""
	mergeOutputConn = "postgres://relspec:secret@localhost/testdb"
	mergeSkipTables = ""
	mergeReportPath = ""
	err := runMerge(nil, nil)
	if err == nil {
		t.Fatal("expected pgsql output merge to fail on column type conflict")
	}
	// The error must carry the human-readable conflict summary header ...
	if !strings.Contains(err.Error(), "column type conflicts detected") {
		t.Fatalf("expected conflict summary in error, got: %v", err)
	}
	// ... including the fully qualified path of the conflicting column.
	if !strings.Contains(err.Error(), "public.users.id") {
		t.Fatalf("expected conflicting column path in error, got: %v", err)
	}
}

View File

@@ -71,6 +71,40 @@ func writeTestJSON(t *testing.T, path string, tableNames []string) {
} }
} }
func writeTestJSONWithSingleColumnType(t *testing.T, path, tableName, columnType string) {
t.Helper()
db := minimalDatabase{
Name: "test_db",
Schemas: []minimalSchema{{
Name: "public",
Tables: []minimalTable{{
Name: tableName,
Schema: "public",
Columns: map[string]minimalColumn{
"id": {
Name: "id",
Table: tableName,
Schema: "public",
Type: columnType,
NotNull: true,
IsPrimaryKey: true,
AutoIncrement: true,
},
},
}},
}},
}
data, err := json.Marshal(db)
if err != nil {
t.Fatalf("failed to marshal test JSON: %v", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatalf("failed to write test file %s: %v", path, err)
}
}
// convertState captures and restores all convert global vars. // convertState captures and restores all convert global vars.
type convertState struct { type convertState struct {
sourceType string sourceType string

View File

@@ -1,6 +1,6 @@
# Maintainer: Hein (Warky Devs) <hein@warky.dev> # Maintainer: Hein (Warky Devs) <hein@warky.dev>
pkgname=relspec pkgname=relspec
pkgver=1.0.52 pkgver=1.0.56
pkgrel=1 pkgrel=1
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs." pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
arch=('x86_64' 'aarch64') arch=('x86_64' 'aarch64')

View File

@@ -1,5 +1,5 @@
Name: relspec Name: relspec
Version: 1.0.52 Version: 1.0.56
Release: 1%{?dist} Release: 1%{?dist}
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs. Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.

View File

@@ -22,6 +22,16 @@ type MergeResult struct {
EnumsAdded int EnumsAdded int
ViewsAdded int ViewsAdded int
SequencesAdded int SequencesAdded int
TypeConflicts []ColumnTypeConflict
}
// ColumnTypeConflict describes a column that exists in both schemas but with incompatible types.
type ColumnTypeConflict struct {
Schema string
Table string
Column string
TargetType string
SourceType string
} }
// MergeOptions contains options for merge operations // MergeOptions contains options for merge operations
@@ -146,11 +156,19 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
// Merge columns // Merge columns
for colName, srcCol := range srcTable.Columns { for colName, srcCol := range srcTable.Columns {
if _, exists := existingColumns[colName]; !exists { if tgtCol, exists := existingColumns[colName]; !exists {
// Column doesn't exist, add it // Column doesn't exist, add it
newCol := cloneColumn(srcCol) newCol := cloneColumn(srcCol)
table.Columns[colName] = newCol table.Columns[colName] = newCol
r.ColumnsAdded++ r.ColumnsAdded++
} else if columnTypeConflict(tgtCol, srcCol) {
r.TypeConflicts = append(r.TypeConflicts, ColumnTypeConflict{
Schema: firstNonEmpty(table.Schema, srcTable.Schema, srcCol.Schema),
Table: firstNonEmpty(table.Name, srcTable.Name, srcCol.Table),
Column: firstNonEmpty(tgtCol.Name, srcCol.Name, colName),
TargetType: describeColumnType(tgtCol),
SourceType: describeColumnType(srcCol),
})
} }
} }
} }
@@ -426,6 +444,52 @@ func cloneColumn(col *models.Column) *models.Column {
return newCol return newCol
} }
// columnTypeConflict reports whether a column present on both sides has an
// incompatible definition: a different normalized type name, or a different
// length/precision/scale modifier. Nil columns never conflict.
func columnTypeConflict(target, source *models.Column) bool {
	if target == nil || source == nil {
		return false
	}
	if normalizeType(target.Type) != normalizeType(source.Type) {
		return true
	}
	if target.Length != source.Length {
		return true
	}
	return target.Precision != source.Precision || target.Scale != source.Scale
}
// normalizeType lowercases and trims a SQL type name, then collapses
// well-known equivalent PostgreSQL spellings (e.g. "character varying" vs
// "varchar", "timestamp with time zone" vs "timestamptz", "int4" vs
// "integer") into one canonical form. Without this, columnTypeConflict
// would flag purely cosmetic spelling differences as real type conflicts,
// which is inconsistent with the equivalence normalization the pgsqltypes
// registry applies elsewhere in this codebase.
func normalizeType(value string) string {
	normalized := strings.ToLower(strings.TrimSpace(value))
	switch normalized {
	case "character varying":
		return "varchar"
	case "character":
		return "char"
	case "timestamp without time zone":
		return "timestamp"
	case "timestamp with time zone":
		return "timestamptz"
	case "time without time zone":
		return "time"
	case "time with time zone":
		return "timetz"
	case "bool":
		return "boolean"
	case "int", "int4":
		return "integer"
	case "int2":
		return "smallint"
	case "int8":
		return "bigint"
	default:
		return normalized
	}
}
// describeColumnType renders a column's type with its size modifier for
// human-readable conflict reports, e.g. "varchar(255)" or "numeric(10,2)".
// Returns "" for nil columns or columns with a blank type name.
func describeColumnType(col *models.Column) string {
	if col == nil {
		return ""
	}
	name := strings.TrimSpace(col.Type)
	if name == "" {
		return ""
	}
	// Precision wins over length; scale is only shown alongside precision.
	if col.Precision > 0 {
		if col.Scale > 0 {
			return fmt.Sprintf("%s(%d,%d)", name, col.Precision, col.Scale)
		}
		return fmt.Sprintf("%s(%d)", name, col.Precision)
	}
	if col.Length > 0 {
		return fmt.Sprintf("%s(%d)", name, col.Length)
	}
	return name
}
// firstNonEmpty returns the first argument whose trimmed form is non-empty,
// or "" when every candidate is blank. Note that the original (untrimmed)
// value is returned, not the trimmed one.
func firstNonEmpty(values ...string) string {
	for _, candidate := range values {
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		return candidate
	}
	return ""
}
func cloneConstraint(constraint *models.Constraint) *models.Constraint { func cloneConstraint(constraint *models.Constraint) *models.Constraint {
if constraint == nil { if constraint == nil {
return nil return nil
@@ -609,6 +673,7 @@ func GetMergeSummary(result *MergeResult) string {
fmt.Sprintf("Enums added: %d", result.EnumsAdded), fmt.Sprintf("Enums added: %d", result.EnumsAdded),
fmt.Sprintf("Relations added: %d", result.RelationsAdded), fmt.Sprintf("Relations added: %d", result.RelationsAdded),
fmt.Sprintf("Domains added: %d", result.DomainsAdded), fmt.Sprintf("Domains added: %d", result.DomainsAdded),
fmt.Sprintf("Type conflicts: %d", len(result.TypeConflicts)),
} }
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded + totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
@@ -625,3 +690,35 @@ func GetMergeSummary(result *MergeResult) string {
return summary return summary
} }
// GetColumnTypeConflictSummary returns a short, human-readable conflict
// summary. At most limit conflicts are listed (limit <= 0 means "all"); any
// remainder is collapsed into a trailing "... and N more" line. Returns ""
// when there are no conflicts or result is nil.
func GetColumnTypeConflictSummary(result *MergeResult, limit int) string {
	if result == nil {
		return ""
	}
	total := len(result.TypeConflicts)
	if total == 0 {
		return ""
	}
	if limit <= 0 || limit > total {
		limit = total
	}
	var b strings.Builder
	b.WriteString("column type conflicts detected:")
	for _, conflict := range result.TypeConflicts[:limit] {
		fmt.Fprintf(&b, "\n - %s.%s.%s: target=%s source=%s",
			conflict.Schema, conflict.Table, conflict.Column, conflict.TargetType, conflict.SourceType)
	}
	if remaining := total - limit; remaining > 0 {
		fmt.Fprintf(&b, "\n ... and %d more", remaining)
	}
	return b.String()
}
// min returns the smaller of two ints.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}

View File

@@ -1,6 +1,7 @@
package merge package merge
import ( import (
"strings"
"testing" "testing"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
@@ -140,6 +141,61 @@ func TestMergeColumns_NewColumn(t *testing.T) {
} }
} }
// TestMergeColumns_TypeConflictIsDetected verifies that merging two schemas
// defining the same column with different types records exactly one
// ColumnTypeConflict (varchar(255) vs text) and never mutates the target
// column definition.
func TestMergeColumns_TypeConflictIsDetected(t *testing.T) {
	// Target side: users.email declared as varchar(255).
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:   "users",
						Schema: "public",
						Columns: map[string]*models.Column{
							"email": {Name: "email", Type: "varchar", Length: 255},
						},
					},
				},
			},
		},
	}
	// Source side: same column, but declared as plain text (no length).
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:   "users",
						Schema: "public",
						Columns: map[string]*models.Column{
							"email": {Name: "email", Type: "text"},
						},
					},
				},
			},
		},
	}
	result := MergeDatabases(target, source, nil)
	if len(result.TypeConflicts) != 1 {
		t.Fatalf("Expected 1 type conflict, got %d", len(result.TypeConflicts))
	}
	conflict := result.TypeConflicts[0]
	// The conflict must identify schema, table and column ...
	if conflict.Schema != "public" || conflict.Table != "users" || conflict.Column != "email" {
		t.Fatalf("Unexpected conflict location: %+v", conflict)
	}
	// ... and render both sides with their size modifiers.
	if conflict.TargetType != "varchar(255)" {
		t.Fatalf("Expected target type varchar(255), got %q", conflict.TargetType)
	}
	if conflict.SourceType != "text" {
		t.Fatalf("Expected source type text, got %q", conflict.SourceType)
	}
	// A detected conflict must leave the target column untouched.
	if got := target.Schemas[0].Tables[0].Columns["email"].Type; got != "varchar" {
		t.Fatalf("Expected target column type to remain unchanged, got %q", got)
	}
}
func TestMergeConstraints_NewConstraint(t *testing.T) { func TestMergeConstraints_NewConstraint(t *testing.T) {
target := &models.Database{ target := &models.Database{
Schemas: []*models.Schema{ Schemas: []*models.Schema{
@@ -509,6 +565,9 @@ func TestGetMergeSummary(t *testing.T) {
ConstraintsAdded: 3, ConstraintsAdded: 3,
IndexesAdded: 2, IndexesAdded: 2,
ViewsAdded: 1, ViewsAdded: 1,
TypeConflicts: []ColumnTypeConflict{
{Schema: "public", Table: "users", Column: "email", TargetType: "varchar(255)", SourceType: "text"},
},
} }
summary := GetMergeSummary(result) summary := GetMergeSummary(result)
@@ -518,6 +577,9 @@ func TestGetMergeSummary(t *testing.T) {
if len(summary) < 50 { if len(summary) < 50 {
t.Errorf("Summary seems too short: %s", summary) t.Errorf("Summary seems too short: %s", summary)
} }
if !strings.Contains(summary, "Type conflicts: 1") {
t.Errorf("Expected type conflict count in summary, got: %s", summary)
}
} }
func TestGetMergeSummary_Nil(t *testing.T) { func TestGetMergeSummary_Nil(t *testing.T) {

View File

@@ -145,6 +145,24 @@ var postgresTypeAliases = map[string]string{
"bool": "boolean", "bool": "boolean",
} }
// postgresEquivalentBaseTypes maps verbose PostgreSQL type spellings to
// their short canonical aliases so that equivalent spellings normalize to
// the same comparable form.
var postgresEquivalentBaseTypes = map[string]string{
	"character varying":           "varchar",
	"character":                   "char",
	"timestamp without time zone": "timestamp",
	"timestamp with time zone":    "timestamptz",
	"time without time zone":      "time",
	"time with time zone":         "timetz",
}

// postgresEquivalentBaseTypeVariants is the inverse view: for each canonical
// short name, every accepted spelling, canonical form listed first.
var postgresEquivalentBaseTypeVariants = map[string][]string{
	"varchar":     {"varchar", "character varying"},
	"char":        {"char", "character"},
	"timestamp":   {"timestamp", "timestamp without time zone"},
	"timestamptz": {"timestamptz", "timestamp with time zone"},
	"time":        {"time", "time without time zone"},
	"timetz":      {"timetz", "time with time zone"},
}
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names. // GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
func GetPostgresBaseTypes() []string { func GetPostgresBaseTypes() []string {
result := make([]string, 0, len(postgresBaseTypes)) result := make([]string, 0, len(postgresBaseTypes))
@@ -212,6 +230,86 @@ func CanonicalizeBaseType(baseType string) string {
return base return base
} }
// EquivalentBaseType resolves broader SQL-equivalent spellings to a common
// comparable form, e.g. "character varying" -> "varchar". Unknown or
// already-canonical names are returned as canonicalized.
func EquivalentBaseType(baseType string) string {
	canonical := CanonicalizeBaseType(baseType)
	equivalent, ok := postgresEquivalentBaseTypes[canonical]
	if !ok {
		return canonical
	}
	return equivalent
}
// NormalizeEquivalentSQLType returns a normalized SQL type string suitable
// for equality checks. Equivalent spellings such as "character varying(255)"
// and "varchar(255)" normalize identically; size modifiers and array
// suffixes ("[]") are preserved.
func NormalizeEquivalentSQLType(sqlType string) string {
	token := normalizeTypeToken(sqlType)
	if token == "" {
		return ""
	}
	// Peel off trailing array markers, counting nesting depth.
	depth := 0
	for strings.HasSuffix(token, "[]") {
		depth++
		token = strings.TrimSpace(strings.TrimSuffix(token, "[]"))
	}
	// Split off a length/precision modifier such as "(255)" or "(10,2)".
	modifier := ""
	if open := strings.Index(token, "("); open >= 0 {
		modifier = strings.TrimSpace(token[open:])
		token = strings.TrimSpace(token[:open])
	}
	return EquivalentBaseType(token) + modifier + strings.Repeat("[]", depth)
}
// EquivalentSQLTypeVariants returns equivalent PostgreSQL spellings for a
// SQL type, canonical short form first. Size modifiers and array suffixes
// are reattached to every variant.
// Examples:
//   - varchar(255) -> ["varchar(255)", "character varying(255)"]
//   - timestamptz  -> ["timestamptz", "timestamp with time zone"]
func EquivalentSQLTypeVariants(sqlType string) []string {
	token := normalizeTypeToken(sqlType)
	if token == "" {
		return nil
	}
	// Strip array markers, remembering depth so they can be reattached.
	depth := 0
	for strings.HasSuffix(token, "[]") {
		depth++
		token = strings.TrimSpace(strings.TrimSuffix(token, "[]"))
	}
	// Detach the size modifier, e.g. "(255)".
	modifier := ""
	if open := strings.Index(token, "("); open >= 0 {
		modifier = strings.TrimSpace(token[open:])
		token = strings.TrimSpace(token[:open])
	}
	base := EquivalentBaseType(token)
	suffix := modifier + strings.Repeat("[]", depth)

	spellings := postgresEquivalentBaseTypeVariants[base]
	if len(spellings) == 0 {
		spellings = []string{base}
	}
	seen := make(map[string]bool, len(spellings))
	variants := make([]string, 0, len(spellings))
	for _, spelling := range spellings {
		full := spelling + suffix
		if seen[full] {
			continue
		}
		seen[full] = true
		variants = append(variants, full)
	}
	return variants
}
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry. // IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
func IsKnownPostgresType(sqlType string) bool { func IsKnownPostgresType(sqlType string) bool {
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType)) base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))

View File

@@ -97,3 +97,51 @@ func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) {
}) })
} }
} }
// TestNormalizeEquivalentSQLType checks that equivalent PostgreSQL type
// spellings normalize to an identical canonical string, with modifiers and
// array suffixes preserved.
func TestNormalizeEquivalentSQLType(t *testing.T) {
	cases := map[string]string{
		"character varying(255)":   "varchar(255)",
		"varchar(255)":             "varchar(255)",
		"timestamp with time zone": "timestamptz",
		"timestamptz":              "timestamptz",
		"time without time zone":   "time",
		"character varying(255)[]": "varchar(255)[]",
	}
	for input, want := range cases {
		t.Run(input, func(t *testing.T) {
			if got := NormalizeEquivalentSQLType(input); got != want {
				t.Fatalf("NormalizeEquivalentSQLType(%q) = %q, want %q", input, got, want)
			}
		})
	}
}
// TestEquivalentSQLTypeVariants checks that equivalent spellings are
// expanded in a stable order (canonical short form first) and that types
// without alternates come back as a single-element slice.
func TestEquivalentSQLTypeVariants(t *testing.T) {
	tests := []struct {
		input string
		want  []string
	}{
		{input: "character varying(255)", want: []string{"varchar(255)", "character varying(255)"}},
		{input: "timestamptz", want: []string{"timestamptz", "timestamp with time zone"}},
		{input: "text[]", want: []string{"text[]"}},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got := EquivalentSQLTypeVariants(tt.input)
			// Length mismatch is reported first, with the full slice for context.
			if len(got) != len(tt.want) {
				t.Fatalf("EquivalentSQLTypeVariants(%q) len = %d, want %d (%v)", tt.input, len(got), len(tt.want), got)
			}
			// Order matters: the canonical spelling must come first.
			for i := range tt.want {
				if got[i] != tt.want[i] {
					t.Fatalf("EquivalentSQLTypeVariants(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i])
				}
			}
		})
	}
}

View File

@@ -18,17 +18,20 @@ type TemplateData struct {
// ModelData represents a single model/struct in the template // ModelData represents a single model/struct in the template
type ModelData struct { type ModelData struct {
Name string Name string
TableName string // schema.table format TableName string // schema.table format
SchemaName string SchemaName string
TableNameOnly string // just table name without schema TableNameOnly string // just table name without schema
Comment string Comment string
Fields []*FieldData Fields []*FieldData
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call) PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
Prefix string // 3-letter prefix PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
} }
// FieldData represents a single field in a struct // FieldData represents a single field in a struct
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
model.IDColumnName = safeName model.IDColumnName = safeName
// Check if PK type is a SQL type (contains resolvespec_common or sql_types) // Check if PK type is a SQL type (contains resolvespec_common or sql_types)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types") model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
break break
} }
} }
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
return comment return comment
} }
// isStringLikePrimaryKeyType reports whether the mapped Go type for a
// primary key should drive string-based GetID/SetID/UpdateID helpers
// instead of the default int64 ones.
func isStringLikePrimaryKeyType(goType string) bool {
	return goType == "string" ||
		goType == "sql.NullString" ||
		goType == "resolvespec_common.SqlString" ||
		goType == "resolvespec_common.SqlUUID"
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names // resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision // and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string { func resolveFieldNameCollision(fieldName string) string {

View File

@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
{{end}} {{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}} {{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value // GetID returns the primary key value
func (m {{.Name}}) GetID() int64 { func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}} {{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64() return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}} {{- else -}}
return int64(m.{{.PrimaryKeyField}}) return int64(m.{{.PrimaryKeyField}})
{{- end}} {{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string // GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string { func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}} {{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value // SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) { func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid) m.UpdateID(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}} {{if .PrimaryKeyIsSQL -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid)) {{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}} {{- else -}}
m.{{.PrimaryKeyField}} = int32(newid) m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}} {{- end}}
} }
{{end}} {{end}}

View File

@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
} }
// needsFmtImport reports whether any generated helper method will call into
// the fmt package: GetIDStr formats numeric (non-SQL-wrapper) primary keys
// with fmt.Sprintf, and UpdateID formats int64 IDs into SQL wrapper types.
// String-like primary keys never need fmt.
func (w *Writer) needsFmtImport(models []*ModelData) bool {
	for _, model := range models {
		if model.PrimaryKeyField == "" || model.PrimaryKeyIsStr {
			continue
		}
		if w.config.GenerateGetIDStr && !model.PrimaryKeyIsSQL {
			return true
		}
		if w.config.GenerateUpdateID && model.PrimaryKeyIsSQL {
			return true
		}
	}
	return false
}
// findTable finds a table by schema and name in the database // findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas { for _, schema := range db.Schemas {

View File

@@ -590,6 +590,116 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
} }
} }
// TestWriter_UpdateIDTypeSafety_Bun verifies that UpdateID casts through the
// actual primary-key Go type (int32, int64, or a SQL wrapper's FromString)
// rather than unconditionally casting to int32, which truncated 64-bit IDs.
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
	tests := []struct {
		name         string
		pkType       string // SQL column type of the primary key
		expectedPK   string // Go type expected on the generated struct field
		expectedLine string // assignment expected inside UpdateID
		forbidInt32  bool   // true when an int32(newid) cast would be a truncation bug
	}{
		{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
		{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
		{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build a one-column table whose only column is the primary key.
			table := models.InitTable("test_table", "public")
			table.Columns["id"] = &models.Column{
				Name:         "id",
				Type:         tt.pkType,
				NotNull:      true,
				IsPrimaryKey: true,
			}
			tmpDir := t.TempDir()
			opts := &writers.WriterOptions{
				PackageName: "models",
				OutputPath:  filepath.Join(tmpDir, "test.go"),
			}
			writer := NewWriter(opts)
			err := writer.WriteTable(table)
			if err != nil {
				t.Fatalf("WriteTable failed: %v", err)
			}
			// Assertions run against the generated source text.
			content, err := os.ReadFile(opts.OutputPath)
			if err != nil {
				t.Fatalf("Failed to read generated file: %v", err)
			}
			generated := string(content)
			if !strings.Contains(generated, tt.expectedLine) {
				t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
			}
			if !strings.Contains(generated, "ID "+tt.expectedPK) {
				t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
			}
			if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
				t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
			}
			// The public helper signature stays int64 for numeric keys.
			if !strings.Contains(generated, "UpdateID(newid int64)") {
				t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
			}
		})
	}
}
// TestWriter_StringPrimaryKeyHelpers_Bun verifies that a uuid primary key
// produces string-typed GetID/SetID/UpdateID helpers (via the SqlUUID
// wrapper) and never falls back to the int64 helper signatures.
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
	// One-column table whose primary key is a uuid.
	table := models.InitTable("accounts", "public")
	table.Columns["id"] = &models.Column{
		Name:         "id",
		Type:         "uuid",
		NotNull:      true,
		IsPrimaryKey: true,
	}
	tmpDir := t.TempDir()
	opts := &writers.WriterOptions{
		PackageName: "models",
		OutputPath:  filepath.Join(tmpDir, "test.go"),
	}
	writer := NewWriter(opts)
	err := writer.WriteTable(table)
	if err != nil {
		t.Fatalf("WriteTable failed: %v", err)
	}
	content, err := os.ReadFile(opts.OutputPath)
	if err != nil {
		t.Fatalf("Failed to read generated file: %v", err)
	}
	generated := string(content)
	// Every fragment below must appear in the generated model source.
	expectations := []string{
		"resolvespec_common.SqlUUID",
		"func (m ModelPublicAccounts) GetID() string",
		"return m.ID.String()",
		"func (m ModelPublicAccounts) GetIDStr() string",
		"func (m ModelPublicAccounts) SetID(newid string)",
		"func (m *ModelPublicAccounts) UpdateID(newid string)",
		"m.ID.FromString(newid)",
	}
	for _, expected := range expectations {
		if !strings.Contains(generated, expected) {
			t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
		}
	}
	// And the int64-based signatures must be absent for string keys.
	if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
		t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
	}
}
func TestTypeMapper_BuildBunTag(t *testing.T) { func TestTypeMapper_BuildBunTag(t *testing.T) {
mapper := NewTypeMapper("") mapper := NewTypeMapper("")

View File

@@ -2,6 +2,7 @@ package gorm
import ( import (
"sort" "sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
@@ -17,17 +18,20 @@ type TemplateData struct {
// ModelData represents a single model/struct in the template // ModelData represents a single model/struct in the template
type ModelData struct { type ModelData struct {
Name string Name string
TableName string // schema.table format TableName string // schema.table format
SchemaName string SchemaName string
TableNameOnly string // just table name without schema TableNameOnly string // just table name without schema
Comment string Comment string
Fields []*FieldData Fields []*FieldData
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
Prefix string // 3-letter prefix PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
} }
// FieldData represents a single field in a struct // FieldData represents a single field in a struct
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
// Sanitize column name to remove backticks // Sanitize column name to remove backticks
safeName := writers.SanitizeStructTagValue(col.Name) safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName) model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
model.IDColumnName = safeName model.IDColumnName = safeName
break break
} }
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
return comment return comment
} }
// isStringLikePrimaryKeyType reports whether the mapped Go type for a
// primary key should drive string-based GetID/SetID/UpdateID helpers
// instead of the default int64 ones.
func isStringLikePrimaryKeyType(goType string) bool {
	return goType == "string" ||
		goType == "sql.NullString" ||
		goType == "sql_types.SqlString" ||
		goType == "sql_types.SqlUUID"
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names // resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision // and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string { func resolveFieldNameCollision(fieldName string) string {

View File

@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
{{end}} {{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}} {{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value // GetID returns the primary key value
func (m {{.Name}}) GetID() int64 { func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return int64(m.{{.PrimaryKeyField}}) return int64(m.{{.PrimaryKeyField}})
{{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string // GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string { func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}} {{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value // SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) { func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid) m.UpdateID(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid) m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}} {{if and .Config.GenerateGetIDName .IDColumnName}}

View File

@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
} }
func (w *Writer) needsFmtImport(models []*ModelData) bool {
if w.config.GenerateGetIDStr {
for _, model := range models {
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
if w.config.GenerateUpdateID {
for _, model := range models {
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
return false
}
// findTable finds a table by schema and name in the database // findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas { for _, schema := range db.Schemas {

View File

@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
} }
} }
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
table := models.InitTable("accounts", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: "uuid",
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
expectations := []string{
"ID string",
"func (m ModelPublicAccounts) GetID() string",
"return m.ID",
"func (m ModelPublicAccounts) GetIDStr() string",
"func (m ModelPublicAccounts) SetID(newid string)",
"func (m *ModelPublicAccounts) UpdateID(newid string)",
"m.ID = newid",
}
for _, expected := range expectations {
if !strings.Contains(generated, expected) {
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
}
}
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
}
}
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) { func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
tests := []struct { tests := []struct {
input string input string

View File

@@ -31,6 +31,10 @@ type MigrationWriter struct {
// NewMigrationWriter creates a new templated migration writer // NewMigrationWriter creates a new templated migration writer
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) { func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
if options == nil {
options = &writers.WriterOptions{}
}
executor, err := NewTemplateExecutor(options.FlattenSchema) executor, err := NewTemplateExecutor(options.FlattenSchema)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create template executor: %w", err) return nil, fmt.Errorf("failed to create template executor: %w", err)
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
// WriteMigration generates migration scripts using templates // WriteMigration generates migration scripts using templates
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error { func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
if model == nil {
return fmt.Errorf("model database is required")
}
if w.options == nil {
w.options = &writers.WriterOptions{}
}
if current == nil {
current = models.InitDatabase(model.Name)
}
var writer io.Writer var writer io.Writer
var file *os.File var file *os.File
var err error var err error
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
// Process each schema in the model // Process each schema in the model
for _, modelSchema := range model.Schemas { for _, modelSchema := range model.Schemas {
if modelSchema == nil {
continue
}
// Find corresponding schema in current database // Find corresponding schema in current database
var currentSchema *models.Schema var currentSchema *models.Schema
for _, cs := range current.Schemas { for _, cs := range current.Schemas {
if cs == nil {
continue
}
if strings.EqualFold(cs.Name, modelSchema.Name) { if strings.EqualFold(cs.Name, modelSchema.Name) {
currentSchema = cs currentSchema = cs
break break
@@ -139,6 +160,17 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0) scripts := make([]MigrationScript, 0)
if schemaRequiresPGTrgm(model) {
scripts = append(scripts, MigrationScript{
ObjectName: "extension.pg_trgm",
ObjectType: "create extension",
Schema: model.Name,
Priority: 80,
Sequence: len(scripts),
Body: "CREATE EXTENSION IF NOT EXISTS pg_trgm;",
})
}
// Phase 1: Drop constraints and indexes that changed (Priority 11-50) // Phase 1: Drop constraints and indexes that changed (Priority 11-50)
if current != nil { if current != nil {
dropScripts, err := w.generateDropScripts(model, current) dropScripts, err := w.generateDropScripts(model, current)
@@ -340,7 +372,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
SchemaName: schema.Name, SchemaName: schema.Name,
TableName: modelTable.Name, TableName: modelTable.Name,
ColumnName: modelCol.Name, ColumnName: modelCol.Name,
ColumnType: pgsql.ConvertSQLType(modelCol.Type), ColumnType: effectiveColumnSQLType(modelCol),
Default: defaultVal, Default: defaultVal,
NotNull: modelCol.NotNull, NotNull: modelCol.NotNull,
}) })
@@ -359,12 +391,13 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
scripts = append(scripts, script) scripts = append(scripts, script)
} else if !columnsEqual(modelCol, currentCol) { } else if !columnsEqual(modelCol, currentCol) {
// Column exists but properties changed // Column exists but properties changed
if modelCol.Type != currentCol.Type { if !columnTypesEqual(modelCol, currentCol) {
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{ sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
SchemaName: schema.Name, SchemaName: schema.Name,
TableName: modelTable.Name, TableName: modelTable.Name,
ColumnName: modelCol.Name, ColumnName: modelCol.Name,
NewType: pgsql.ConvertSQLType(modelCol.Type), NewType: effectiveAlterColumnSQLType(modelCol),
UsingExpr: buildAlterColumnUsingExpression(modelCol.Name, effectiveAlterColumnSQLType(modelCol)),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -545,12 +578,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
indexType = modelIndex.Type indexType = modelIndex.Type
} }
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
if len(columnExprs) == 0 {
continue
}
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{ sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
SchemaName: model.Name, SchemaName: model.Name,
TableName: modelTable.Name, TableName: modelTable.Name,
IndexName: indexName, IndexName: indexName,
IndexType: indexType, IndexType: indexType,
Columns: strings.Join(modelIndex.Columns, ", "), Columns: strings.Join(columnExprs, ", "),
Unique: modelIndex.Unique, Unique: modelIndex.Unique,
}) })
if err != nil { if err != nil {
@@ -573,6 +611,26 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
return scripts, nil return scripts, nil
} }
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
if table != nil {
if col, ok := resolveIndexColumn(table, colName); ok && col != nil {
colExpr = col.SQLName()
if strings.EqualFold(indexType, "gin") {
opClass := ginOperatorClassForColumn(col, index.Comment)
if opClass != "" {
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
}
}
}
columnExprs = append(columnExprs, colExpr)
}
return columnExprs
}
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates // generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0) scripts := make([]MigrationScript, 0)
@@ -828,11 +886,21 @@ func columnsEqual(col1, col2 *models.Column) bool {
if col1 == nil || col2 == nil { if col1 == nil || col2 == nil {
return false return false
} }
return strings.EqualFold(col1.Type, col2.Type) && return columnTypesEqual(col1, col2) &&
col1.NotNull == col2.NotNull && col1.NotNull == col2.NotNull &&
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default) fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
} }
func columnTypesEqual(col1, col2 *models.Column) bool {
if col1 == nil || col2 == nil {
return false
}
return strings.EqualFold(
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col1)),
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col2)),
)
}
// constraintsEqual checks if two constraints are equal // constraintsEqual checks if two constraints are equal
func constraintsEqual(c1, c2 *models.Constraint) bool { func constraintsEqual(c1, c2 *models.Constraint) bool {
if c1 == nil || c2 == nil { if c1 == nil || c2 == nil {

View File

@@ -97,6 +97,370 @@ func TestWriteMigration_ArrayDefault(t *testing.T) {
} }
} }
func TestWriteMigration_AltersColumnTypeWhenActualTypeDiffers(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("learnings", "public")
currentDetails := models.InitColumn("details", "learnings", "public")
currentDetails.Type = "jsonb"
currentTable.Columns["details"] = currentDetails
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("learnings", "public")
modelDetails := models.InitColumn("details", "learnings", "public")
modelDetails.Type = "text"
modelTable.Columns["details"] = modelDetails
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "ALTER TABLE public.learnings") || !strings.Contains(output, "ALTER COLUMN details TYPE text") {
t.Fatalf("expected migration to alter mismatched column type, got:\n%s", output)
}
if !strings.Contains(output, `ALTER COLUMN details TYPE text USING details::text;`) {
t.Fatalf("expected migration type alter to include USING cast, got:\n%s", output)
}
}
func TestWriteMigration_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("learnings", "public")
currentID := models.InitColumn("id", "learnings", "public")
currentID.Type = "uuid"
currentTable.Columns["id"] = currentID
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("learnings", "public")
modelID := models.InitColumn("id", "learnings", "public")
modelID.Type = "bigserial"
modelTable.Columns["id"] = modelID
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
}
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
}
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
}
}
func TestWriteMigration_ArrayAlterIncludesUsingCast(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("learnings", "public")
currentTags := models.InitColumn("tags", "learnings", "public")
currentTags.Type = "text"
currentTable.Columns["tags"] = currentTags
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("learnings", "public")
modelTags := models.InitColumn("tags", "learnings", "public")
modelTags.Type = "text[]"
modelTable.Columns["tags"] = modelTags
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
t.Fatalf("expected array alter to include USING cast, got:\n%s", output)
}
}
func TestWriteMigration_DoesNotAlterEquivalentNormalizedColumnType(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("users", "public")
currentEmail := models.InitColumn("email", "users", "public")
currentEmail.Type = "character varying"
currentEmail.Length = 255
currentTable.Columns["email"] = currentEmail
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("users", "public")
modelEmail := models.InitColumn("email", "users", "public")
modelEmail.Type = "varchar(255)"
modelTable.Columns["email"] = modelEmail
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if strings.Contains(output, "ALTER COLUMN email TYPE") {
t.Fatalf("did not expect alter type for equivalent normalized types, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("articles", "public")
titleCol := models.InitColumn("title", "articles", "public")
titleCol.Type = "text"
table.Columns["title"] = titleCol
index := &models.Index{
Name: "idx_articles_title_gin",
Type: "gin",
Columns: []string{"title"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") {
t.Fatalf("expected trigram extension for text GIN migration index, got:\n%s", output)
}
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("agent_personas", "public")
nameCol := models.InitColumn("name", "agent_personas", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
index := &models.Index{
Name: "idx_agent_personas_name_gin",
Type: "gin",
Columns: []string{`"name"`},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (name gin_trgm_ops)") {
t.Fatalf("expected quoted text column GIN index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
index := &models.Index{
Name: "idx_plans_tags",
Type: "gin",
Columns: []string{"tags"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (tags array_ops)") {
t.Fatalf("expected GIN array index with array_ops, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
detailsCol := models.InitColumn("details", "learnings", "public")
detailsCol.Type = "jsonb"
table.Columns["details"] = detailsCol
index := &models.Index{
Name: "idx_learnings_details",
Type: "gin",
Columns: []string{"details"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for jsonb migration index, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnJSONBIgnoresIncompatibleTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
detailsCol := models.InitColumn("details", "learnings", "public")
detailsCol.Type = "jsonb"
table.Columns["details"] = detailsCol
index := &models.Index{
Name: "idx_learnings_details",
Type: "gin",
Columns: []string{"details"},
Comment: "gin_trgm_ops",
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
t.Fatalf("expected incompatible trigram hint on jsonb to fall back to jsonb_ops, got:\n%s", output)
}
}
func TestWriteMigration_WithAudit(t *testing.T) { func TestWriteMigration_WithAudit(t *testing.T) {
// Current database (empty) // Current database (empty)
current := models.InitDatabase("testdb") current := models.InitDatabase("testdb")
@@ -322,3 +686,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
t.Error("Migration missing FOREIGN KEY") t.Error("Migration missing FOREIGN KEY")
} }
} }
func TestNewMigrationWriter_NilOptions(t *testing.T) {
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
}
if writer == nil {
t.Fatal("expected writer instance")
}
if writer.options == nil {
t.Fatal("expected default writer options to be initialized")
}
}
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
table.Columns["id"] = idCol
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, nil); err != nil {
t.Fatalf("WriteMigration with nil current failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "CREATE TABLE") {
t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
}
}

View File

@@ -95,6 +95,16 @@ type AlterColumnTypeData struct {
TableName string TableName string
ColumnName string ColumnName string
NewType string NewType string
UsingExpr string
}
type AlterColumnTypeWithCheckData struct {
SchemaName string
TableName string
ColumnName string
NewType string
EquivalentTypes string
UsingExpr string
} }
// AlterColumnDefaultData contains data for alter column default template // AlterColumnDefaultData contains data for alter column default template
@@ -267,6 +277,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
ConstraintName string ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'" AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string Columns string
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
} }
// Execute methods for each template // Execute methods for each template
@@ -301,6 +312,15 @@ func (te *TemplateExecutor) ExecuteAlterColumnType(data AlterColumnTypeData) (st
return buf.String(), nil return buf.String(), nil
} }
func (te *TemplateExecutor) ExecuteAlterColumnTypeWithCheck(data AlterColumnTypeWithCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "alter_column_type_with_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute alter_column_type_with_check template: %w", err)
}
return buf.String(), nil
}
// ExecuteAlterColumnDefault executes the alter column default template // ExecuteAlterColumnDefault executes the alter column default template
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) { func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
var buf bytes.Buffer var buf bytes.Buffer

View File

@@ -1,2 +1,2 @@
ALTER TABLE {{qual_table .SchemaName .TableName}} ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}; ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};

View File

@@ -0,0 +1,22 @@
DO $$
DECLARE
current_type text;
BEGIN
SELECT pg_catalog.format_type(a.atttypid, a.atttypmod)
INTO current_type
FROM pg_attribute a
JOIN pg_class t ON t.oid = a.attrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
WHERE n.nspname = '{{.SchemaName}}'
AND t.relname = '{{.TableName}}'
AND a.attname = '{{.ColumnName}}'
AND a.attnum > 0
AND NOT a.attisdropped;
IF current_type IS NOT NULL
AND current_type <> ALL(ARRAY[{{.EquivalentTypes}}]) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
END IF;
END;
$$;

View File

@@ -1,26 +1,42 @@
DO $$ DO $$
DECLARE DECLARE
auto_pk_name text; current_pk_name text;
current_pk_matches boolean := false;
BEGIN BEGIN
-- Drop auto-generated primary key if it exists SELECT tc.constraint_name,
SELECT constraint_name INTO auto_pk_name COALESCE(
FROM information_schema.table_constraints ARRAY(
WHERE table_schema = '{{.SchemaName}}' SELECT a.attname::text
AND table_name = '{{.TableName}}' FROM pg_constraint c
AND constraint_type = 'PRIMARY KEY' JOIN pg_class t ON t.oid = c.conrelid
AND constraint_name IN ({{.AutoGenNames}}); JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
ON TRUE
JOIN pg_attribute a
ON a.attrelid = t.oid
AND a.attnum = cols.attnum
WHERE c.contype = 'p'
AND n.nspname = '{{.SchemaName}}'
AND t.relname = '{{.TableName}}'
ORDER BY cols.ord
),
ARRAY[]::text[]
) = ARRAY[{{.ColumnNames}}]
INTO current_pk_name, current_pk_matches
FROM information_schema.table_constraints tc
WHERE tc.table_schema = '{{.SchemaName}}'
AND tc.table_name = '{{.TableName}}'
AND tc.constraint_type = 'PRIMARY KEY';
IF auto_pk_name IS NOT NULL THEN IF current_pk_name IS NOT NULL
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name); AND NOT current_pk_matches
AND current_pk_name IN ({{.AutoGenNames}}) THEN
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
END IF; END IF;
-- Add named primary key if it doesn't exist -- Add the desired primary key only when no matching primary key already exists.
IF NOT EXISTS ( IF current_pk_name IS NULL
SELECT 1 FROM information_schema.table_constraints OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}}); ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF; END IF;
END; END;

View File

@@ -143,6 +143,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName())) statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
} }
if schemaRequiresPGTrgm(schema) {
statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`)
}
// Phase 2: Create sequences // Phase 2: Create sequences
for _, table := range schema.Tables { for _, table := range schema.Tables {
pk := table.GetPrimaryKey() pk := table.GetPrimaryKey()
@@ -181,6 +185,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
} }
statements = append(statements, addColStmts...) statements = append(statements, addColStmts...)
alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate alter column type statements: %w", err)
}
statements = append(statements, alterTypeStmts...)
// Phase 4: Primary keys // Phase 4: Primary keys
for _, table := range schema.Tables { for _, table := range schema.Tables {
// First check for explicit PrimaryKeyConstraint // First check for explicit PrimaryKeyConstraint
@@ -228,6 +238,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
ConstraintName: pkName, ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames), AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "), Columns: strings.Join(pkColumns, ", "),
ColumnNames: formatStringList(pkColumns),
} }
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -260,14 +271,11 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
columnExprs := make([]string, 0, len(index.Columns)) columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns { for _, colName := range index.Columns {
colExpr := colName colExpr := colName
if col, ok := table.Columns[colName]; ok { if col, ok := resolveIndexColumn(table, colName); ok {
// For GIN indexes on text columns, add operator class if strings.EqualFold(indexType, "gin") {
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) { if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" {
opClass := extractOperatorClass(index.Comment) colExpr = fmt.Sprintf("%s %s", colName, opClass)
if opClass == "" {
opClass = "gin_trgm_ops"
} }
colExpr = fmt.Sprintf("%s %s", colName, opClass)
} }
} }
columnExprs = append(columnExprs, colExpr) columnExprs = append(columnExprs, colExpr)
@@ -436,6 +444,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
return statements, nil return statements, nil
} }
// GenerateAlterColumnTypeStatements generates guarded ALTER COLUMN ... TYPE
// statements for every column of every table in the schema. Each statement
// comes from the executor's template, which compares the live column type
// against a list of equivalent spellings before altering, so re-running the
// migration is idempotent.
//
// NOTE: the first element of the returned slice is always the section header
// comment; writeAlterColumnTypes relies on this when it skips element 0.
func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) {
	statements := []string{}
	statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name))
	for _, table := range schema.Tables {
		// Sorted iteration keeps output deterministic across runs.
		columns := getSortedColumns(table.Columns)
		for _, col := range columns {
			// Serial pseudo-types are mapped to their storage types
			// (e.g. bigserial -> bigint): ALTER TYPE cannot target them.
			targetType := effectiveAlterColumnSQLType(col)
			stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{
				SchemaName:      schema.Name,
				TableName:       table.Name,
				ColumnName:      col.Name,
				NewType:         targetType,
				EquivalentTypes: equivalentTypeListSQL(targetType),
				UsingExpr:       buildAlterColumnUsingExpression(col.Name, targetType),
			})
			if err != nil {
				return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
			}
			statements = append(statements, stmt)
		}
	}
	return statements, nil
}
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database // GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) { func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
statements := []string{} statements := []string{}
@@ -488,31 +523,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
func (w *Writer) generateColumnDefinition(col *models.Column) string { func (w *Writer) generateColumnDefinition(col *models.Column) string {
parts := []string{col.SQLName()} parts := []string{col.SQLName()}
// Type with length/precision - convert to valid PostgreSQL type parts = append(parts, effectiveColumnSQLType(col))
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
// Only add size specifiers for types that support them
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
if pgsql.SupportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
// Convert text with length to varchar
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
// For types that don't support length (integer, bigint, etc.), ignore the length
} else if !hasExplicitTypeModifier && col.Precision > 0 {
if pgsql.SupportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
// For types that don't support precision, ignore it
}
parts = append(parts, typeStr)
// NOT NULL // NOT NULL
if col.NotNull { if col.NotNull {
@@ -534,6 +545,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
return strings.Join(parts, " ") return strings.Join(parts, " ")
} }
// effectiveColumnSQLType resolves the PostgreSQL type spelling for a column,
// attaching length or precision/scale modifiers only when the base type
// supports them and the converted type does not already carry an explicit
// modifier. A nil column yields "".
func effectiveColumnSQLType(col *models.Column) string {
	if col == nil {
		return ""
	}
	base := pgsql.ConvertSQLType(col.Type)
	if pgsql.HasExplicitTypeModifier(base) {
		// The converted type already includes a modifier; keep it as-is.
		return base
	}
	switch {
	case col.Length > 0 && col.Precision == 0:
		if pgsql.SupportsLength(base) {
			return fmt.Sprintf("%s(%d)", base, col.Length)
		}
		if isTextTypeWithoutLength(base) {
			// A length on bare "text" is emitted as varchar(n).
			return fmt.Sprintf("varchar(%d)", col.Length)
		}
	case col.Precision > 0:
		if pgsql.SupportsPrecision(base) {
			if col.Scale > 0 {
				return fmt.Sprintf("%s(%d,%d)", base, col.Precision, col.Scale)
			}
			return fmt.Sprintf("%s(%d)", base, col.Precision)
		}
	}
	// Types that support neither modifier ignore length/precision entirely.
	return base
}
// effectiveAlterColumnSQLType returns the type usable in an
// ALTER COLUMN ... TYPE clause. The serial pseudo-types are not valid there,
// so they map to their underlying storage types; everything else passes
// through unchanged.
func effectiveAlterColumnSQLType(col *models.Column) string {
	resolved := effectiveColumnSQLType(col)
	switch strings.ToLower(strings.TrimSpace(resolved)) {
	case "smallserial":
		return "smallint"
	case "serial":
		return "integer"
	case "bigserial":
		return "bigint"
	}
	return resolved
}
// buildAlterColumnUsingExpression builds the USING-clause expression that
// casts the existing column value to the target type. It returns "" when
// either input is blank so the template can omit the clause.
func buildAlterColumnUsingExpression(columnName, targetType string) string {
	name := strings.TrimSpace(columnName)
	typ := strings.TrimSpace(targetType)
	if name == "" || typ == "" {
		return ""
	}
	// The untrimmed originals are used in the output on purpose.
	return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType)
}
// equivalentTypeListSQL renders the equivalent spellings of a SQL type as a
// comma-separated list of single-quoted literals, suitable for embedding in
// the guard query of an ALTER COLUMN template. No variants yields "".
func equivalentTypeListSQL(sqlType string) string {
	var b strings.Builder
	for i, variant := range pgsql.EquivalentSQLTypeVariants(sqlType) {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString("'")
		b.WriteString(escapeQuote(variant))
		b.WriteString("'")
	}
	return b.String()
}
// WriteSchema writes a single schema and all its tables // WriteSchema writes a single schema and all its tables
func (w *Writer) WriteSchema(schema *models.Schema) error { func (w *Writer) WriteSchema(schema *models.Schema) error {
if w.writer == nil { if w.writer == nil {
@@ -545,6 +614,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err return err
} }
if err := w.writeRequiredExtensions(schema); err != nil {
return err
}
// Phase 2: Create sequences (priority 80) // Phase 2: Create sequences (priority 80)
if err := w.writeSequences(schema); err != nil { if err := w.writeSequences(schema); err != nil {
return err return err
@@ -560,6 +633,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err return err
} }
if err := w.writeAlterColumnTypes(schema); err != nil {
return err
}
// Phase 4: Create primary keys (priority 160) // Phase 4: Create primary keys (priority 160)
if err := w.writePrimaryKeys(schema); err != nil { if err := w.writePrimaryKeys(schema); err != nil {
return err return err
@@ -660,6 +737,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error {
return nil return nil
} }
// writeRequiredExtensions emits CREATE EXTENSION statements required by the
// schema. Currently only pg_trgm, which is needed when any GIN index resolves
// to the gin_trgm_ops operator class (see schemaRequiresPGTrgm).
func (w *Writer) writeRequiredExtensions(schema *models.Schema) error {
	if !schemaRequiresPGTrgm(schema) {
		return nil
	}
	fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;")
	// Trailing blank line separates this section from the next.
	fmt.Fprintln(w.writer)
	return nil
}
// writeSequences generates CREATE SEQUENCE statements for identity columns // writeSequences generates CREATE SEQUENCE statements for identity columns
func (w *Writer) writeSequences(schema *models.Schema) error { func (w *Writer) writeSequences(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name) fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
@@ -753,6 +840,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
return nil return nil
} }
// writeAlterColumnTypes writes the guarded ALTER COLUMN ... TYPE statements
// for the schema to the output writer, preceded by a section header.
//
// Fix: the previous version sliced statements[1:] unconditionally, which
// panics if the generator ever returns an empty slice; the header skip is now
// guarded by a length check.
func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name)
	statements, err := w.GenerateAlterColumnTypeStatements(schema)
	if err != nil {
		return err
	}
	// The generator's first element duplicates the header comment printed
	// above, so skip it when present.
	if len(statements) > 0 {
		statements = statements[1:]
	}
	for _, stmt := range statements {
		fmt.Fprintln(w.writer, stmt)
	}
	return nil
}
// writePrimaryKeys generates ALTER TABLE statements for primary keys // writePrimaryKeys generates ALTER TABLE statements for primary keys
func (w *Writer) writePrimaryKeys(schema *models.Schema) error { func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name) fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
@@ -806,6 +908,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
ConstraintName: pkName, ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames), AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "), Columns: strings.Join(columnNames, ", "),
ColumnNames: formatStringList(columnNames),
} }
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -853,15 +956,13 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
// Build column list with operator class support for GIN indexes // Build column list with operator class support for GIN indexes
columnExprs := make([]string, 0, len(index.Columns)) columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns { for _, colName := range index.Columns {
if col, ok := table.Columns[colName]; ok { if col, ok := resolveIndexColumn(table, colName); ok {
colExpr := col.SQLName() colExpr := col.SQLName()
// For GIN indexes on text columns, add operator class if strings.EqualFold(index.Type, "gin") {
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) { opClass := ginOperatorClassForColumn(col, index.Comment)
opClass := extractOperatorClass(index.Comment) if opClass != "" {
if opClass == "" { colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
opClass = "gin_trgm_ops"
} }
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
} }
columnExprs = append(columnExprs, colExpr) columnExprs = append(columnExprs, colExpr)
} }
@@ -1248,23 +1349,126 @@ func isIntegerType(colType string) bool {
} }
// isTextType checks if a column type is a text type (for GIN index operator class) // isTextType checks if a column type is a text type (for GIN index operator class)
func isTextType(colType string) bool { // func isTextType(colType string) bool {
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"} // textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
lowerType := strings.ToLower(colType) // lowerType := strings.ToLower(colType)
if strings.HasSuffix(lowerType, "[]") { // if strings.HasSuffix(lowerType, "[]") {
// return false
// }
// for _, t := range textTypes {
// if strings.HasPrefix(lowerType, t) {
// return true
// }
// }
// return false
// }
// isTextTypeWithoutLength reports whether colType is the bare "text" type
// (case-insensitive); such columns are emitted as varchar(n) when a length
// is specified.
func isTextTypeWithoutLength(colType string) bool {
	const bare = "text"
	return strings.EqualFold(colType, bare)
}
// ginOperatorClassForColumn picks the GIN operator class for a column.
// Precedence: an operator class requested via the index comment wins when it
// is compatible with the column's type; otherwise array columns get
// array_ops, text-like columns gin_trgm_ops, and jsonb columns jsonb_ops.
// For any other base type the (possibly empty, possibly incompatible)
// requested value is returned unchanged.
func ginOperatorClassForColumn(col *models.Column, comment string) string {
	if col == nil {
		return ""
	}
	sqlType := effectiveColumnSQLType(col)
	baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
	isArray := pgsql.IsArrayType(sqlType)
	// The index comment may carry an explicit operator class request.
	requested := extractOperatorClass(comment)
	if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) {
		return requested
	}
	if isArray {
		return "array_ops"
	}
	switch {
	case isTextGinBaseType(baseType):
		return "gin_trgm_ops"
	case baseType == "jsonb":
		return "jsonb_ops"
	default:
		// No known default for this type; fall back to whatever was
		// requested (may be "" so the caller omits the opclass).
		return requested
	}
}
// ginOperatorClassCompatible reports whether opClass may be applied to a
// column with the given base type and array-ness. Operator classes this
// function does not recognize are assumed compatible.
func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool {
	switch opClass {
	case "gin_trgm_ops", "gin_bigm_ops":
		// Trigram/bigram classes only apply to scalar text-like columns.
		if isArray {
			return false
		}
		return isTextGinBaseType(baseType)
	case "jsonb_ops", "jsonb_path_ops":
		return baseType == "jsonb" && !isArray
	case "array_ops":
		return isArray
	}
	return true
}
func isTextGinBaseType(baseType string) bool {
switch baseType {
case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar":
return true
default:
return false return false
} }
for _, t := range textTypes { }
if strings.HasPrefix(lowerType, t) {
return true func schemaRequiresPGTrgm(schema *models.Schema) bool {
if schema == nil {
return false
}
for _, table := range schema.Tables {
if table == nil {
continue
}
for _, index := range table.Indexes {
if index == nil || !strings.EqualFold(index.Type, "gin") {
continue
}
for _, colName := range index.Columns {
col, ok := resolveIndexColumn(table, colName)
if !ok || col == nil {
continue
}
if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" {
return true
}
}
} }
} }
return false return false
} }
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified) func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
func isTextTypeWithoutLength(colType string) bool { if table == nil {
return strings.EqualFold(colType, "text") return nil, false
}
if col, ok := table.Columns[colName]; ok && col != nil {
return col, true
}
normalized := strings.ToLower(strings.Trim(colName, `"`))
for key, col := range table.Columns {
if col == nil {
continue
}
if strings.ToLower(strings.Trim(key, `"`)) == normalized {
return col, true
}
if strings.ToLower(strings.Trim(col.Name, `"`)) == normalized {
return col, true
}
if strings.ToLower(strings.Trim(col.SQLName(), `"`)) == normalized {
return col, true
}
}
return nil, false
} }
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list // formatStringList formats a list of strings as a SQL-safe comma-separated quoted list

View File

@@ -116,14 +116,88 @@ func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *test
} }
output := buf.String() output := buf.String()
if !strings.Contains(output, `USING gin (tags)`) { if !strings.Contains(output, `USING gin (tags array_ops)`) {
t.Fatalf("expected GIN index on array column without explicit trigram opclass, got:\n%s", output) t.Fatalf("expected GIN index on array column with array_ops, got:\n%s", output)
} }
if strings.Contains(output, "gin_trgm_ops") { if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output) t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
} }
} }
// TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass
// verifies that a GIN index whose column list spells the column with
// surrounding double quotes still resolves to the underlying text column,
// receives the gin_trgm_ops operator class, and triggers emission of the
// pg_trgm extension.
func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("agent_personas", "public")
	nameCol := models.InitColumn("name", "agent_personas", "public")
	nameCol.Type = "text"
	table.Columns["name"] = nameCol
	index := &models.Index{
		Name: "idx_agent_personas_name_gin",
		Type: "gin",
		// Quoted on purpose: exercises quote-insensitive column resolution.
		Columns: []string{`"name"`},
	}
	table.Indexes[index.Name] = index
	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)
	var buf bytes.Buffer
	writer := NewWriter(&writers.WriterOptions{})
	writer.writer = &buf
	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	output := buf.String()
	// gin_trgm_ops requires the pg_trgm extension to exist first.
	if !strings.Contains(output, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) {
		t.Fatalf("expected trigram extension for text GIN index, got:\n%s", output)
	}
	if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) {
		t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output)
	}
}
// TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass verifies that a GIN
// index over a jsonb column is emitted with the jsonb_ops operator class and
// never with the text-only gin_trgm_ops class.
func TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("learnings", "public")
	detailsCol := models.InitColumn("details", "learnings", "public")
	detailsCol.Type = "jsonb"
	table.Columns["details"] = detailsCol
	index := &models.Index{
		Name:    "idx_learnings_details",
		Type:    "gin",
		Columns: []string{"details"},
	}
	table.Indexes[index.Name] = index
	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)
	var buf bytes.Buffer
	writer := NewWriter(&writers.WriterOptions{})
	writer.writer = &buf
	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	output := buf.String()
	if !strings.Contains(output, `USING gin (details jsonb_ops)`) {
		t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
	}
	// gin_trgm_ops is only valid for text-like columns.
	if strings.Contains(output, "gin_trgm_ops") {
		t.Fatalf("did not expect gin_trgm_ops for jsonb column, got:\n%s", output)
	}
}
func TestWriteForeignKeys(t *testing.T) { func TestWriteForeignKeys(t *testing.T) {
// Create a test database with two related tables // Create a test database with two related tables
db := models.InitDatabase("testdb") db := models.InitDatabase("testdb")
@@ -673,9 +747,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output) t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
} }
// Verify it checks for our specific named constraint before adding it // Verify it compares the current primary key columns before dropping/recreating
if !strings.Contains(output, "constraint_name = 'pk_public_products'") { if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output) t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
}
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
} }
} }
@@ -766,6 +845,43 @@ func TestColumnSizeSpecifiers(t *testing.T) {
} }
} }
// TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey
// verifies the generated primary-key DO block compares the live PK's column
// list against the desired one (current_pk_matches / ARRAY['id']) and only
// drops an auto-generated PK when the columns do NOT match.
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("learnings", "public")
	idCol := models.InitColumn("id", "learnings", "public")
	idCol.Type = "bigint"
	idCol.IsPrimaryKey = true
	table.Columns["id"] = idCol
	// Extra non-PK column so the table is not trivially single-column.
	parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
	parentCol.Type = "bigint"
	table.Columns["duplicate_of_learning_id"] = parentCol
	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)
	var buf bytes.Buffer
	writer := NewWriter(&writers.WriterOptions{})
	writer.writer = &buf
	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	output := buf.String()
	if !strings.Contains(output, "current_pk_matches") {
		t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
	}
	if !strings.Contains(output, "ARRAY['id']") {
		t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
	}
	if !strings.Contains(output, "NOT current_pk_matches") {
		t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
	}
}
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) { func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
writer := NewWriter(&writers.WriterOptions{}) writer := NewWriter(&writers.WriterOptions{})
@@ -942,3 +1058,82 @@ func TestWriteAddColumnStatements(t *testing.T) {
t.Errorf("Output missing DO block\nFull output:\n%s", output) t.Errorf("Output missing DO block\nFull output:\n%s", output)
} }
} }
// TestWriteSchema_EmitsGuardedAlterColumnTypeStatements verifies the writer
// emits an "Alter column types" section whose statements check the live
// column type (pg_catalog.format_type) against equivalent spellings before
// altering, and that array types get a USING cast.
func TestWriteSchema_EmitsGuardedAlterColumnTypeStatements(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("agent_skills", "public")
	nameCol := models.InitColumn("name", "agent_skills", "public")
	nameCol.Type = "character varying"
	nameCol.Length = 255
	table.Columns["name"] = nameCol
	tagsCol := models.InitColumn("tags", "agent_skills", "public")
	tagsCol.Type = "text[]"
	table.Columns["tags"] = tagsCol
	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)
	var buf bytes.Buffer
	writer := NewWriter(&writers.WriterOptions{})
	writer.writer = &buf
	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	output := buf.String()
	if !strings.Contains(output, "-- Alter column types for schema: public") {
		t.Fatalf("expected alter column type section, got:\n%s", output)
	}
	// The guard must inspect the live type, not assume it.
	if !strings.Contains(output, "pg_catalog.format_type") {
		t.Fatalf("expected guarded live-type check, got:\n%s", output)
	}
	if !strings.Contains(output, "ALTER COLUMN name TYPE character varying(255)") {
		t.Fatalf("expected guarded alter for character varying(255), got:\n%s", output)
	}
	// Both spellings of varchar(255) must be treated as already-correct.
	if !strings.Contains(output, "ARRAY['varchar(255)', 'character varying(255)']") {
		t.Fatalf("expected equivalent type spellings for varchar guard, got:\n%s", output)
	}
	if !strings.Contains(output, "ALTER COLUMN tags TYPE text[]") {
		t.Fatalf("expected guarded alter for array type, got:\n%s", output)
	}
	if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
		t.Fatalf("expected guarded alter for array type to include USING cast, got:\n%s", output)
	}
}
// TestWriteSchema_UsesStorageTypeForSerialAlterStatements verifies that a
// bigserial column's ALTER COLUMN TYPE statement targets the bigint storage
// type (with a USING cast), since serial pseudo-types are invalid in
// ALTER COLUMN ... TYPE clauses.
func TestWriteSchema_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("learnings", "public")
	idCol := models.InitColumn("id", "learnings", "public")
	idCol.Type = "bigserial"
	table.Columns["id"] = idCol
	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)
	var buf bytes.Buffer
	writer := NewWriter(&writers.WriterOptions{})
	writer.writer = &buf
	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	output := buf.String()
	if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
		t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
	}
	// "TYPE bigserial" is a PostgreSQL syntax error; it must never be emitted.
	if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
		t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
	}
	if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
		t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
	}
}