feat(writer): 🎉 Enhance primary key handling and add tests
All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m11s
CI / Build (push) Successful in -26m43s
CI / Lint (push) Successful in -26m34s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m20s

* Implement checks for existing primary keys before adding new ones.
* Drop auto-generated primary keys if they exist.
* Add tests for primary key existence and column size specifiers.
* Improve type conversion handling for PostgreSQL compatibility.
This commit is contained in:
2026-01-31 18:59:32 +02:00
parent 51ab29f8e3
commit c36b5ede2b
2 changed files with 296 additions and 39 deletions

View File

@@ -179,27 +179,67 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
} }
} }
var pkColumns []string
var pkName string
if pkConstraint != nil { if pkConstraint != nil {
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)", pkColumns = pkConstraint.Columns
schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", ")) pkName = pkConstraint.Name
statements = append(statements, stmt)
} else { } else {
// No explicit constraint, check for columns with IsPrimaryKey = true // No explicit constraint, check for columns with IsPrimaryKey = true
pkColumns := []string{} pkCols := []string{}
for _, col := range table.Columns { for _, col := range table.Columns {
if col.IsPrimaryKey { if col.IsPrimaryKey {
pkColumns = append(pkColumns, col.SQLName()) pkCols = append(pkCols, col.SQLName())
} }
} }
if len(pkColumns) > 0 { if len(pkCols) > 0 {
// Sort for consistent output // Sort for consistent output
sort.Strings(pkColumns) sort.Strings(pkCols)
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName()) pkColumns = pkCols
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)", pkName = fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
statements = append(statements, stmt)
} }
} }
if len(pkColumns) > 0 {
// Auto-generated primary key names to check for and drop
autoGenPKNames := []string{
fmt.Sprintf("%s_pkey", table.Name),
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
}
// Wrap in DO block to drop auto-generated PK and add our named PK
stmt := fmt.Sprintf("DO $$\nDECLARE\n"+
" auto_pk_name text;\n"+
"BEGIN\n"+
" -- Drop auto-generated primary key if it exists\n"+
" SELECT constraint_name INTO auto_pk_name\n"+
" FROM information_schema.table_constraints\n"+
" WHERE table_schema = '%s'\n"+
" AND table_name = '%s'\n"+
" AND constraint_type = 'PRIMARY KEY'\n"+
" AND constraint_name IN (%s);\n"+
"\n"+
" IF auto_pk_name IS NOT NULL THEN\n"+
" EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n"+
" END IF;\n"+
"\n"+
" -- Add named primary key if it doesn't exist\n"+
" IF NOT EXISTS (\n"+
" SELECT 1 FROM information_schema.table_constraints\n"+
" WHERE table_schema = '%s'\n"+
" AND table_name = '%s'\n"+
" AND constraint_name = '%s'\n"+
" ) THEN\n"+
" ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+
" END IF;\n"+
"END;\n$$",
schema.Name, table.Name, formatStringList(autoGenPKNames),
schema.SQLName(), table.SQLName(),
schema.Name, table.Name, pkName,
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
statements = append(statements, stmt)
}
} }
// Phase 5: Indexes // Phase 5: Indexes
@@ -270,7 +310,18 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
onUpdate = "NO ACTION" onUpdate = "NO ACTION"
} }
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s", // Wrap in DO block to check for existing constraint
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
" IF NOT EXISTS (\n"+
" SELECT 1 FROM information_schema.table_constraints\n"+
" WHERE table_schema = '%s'\n"+
" AND table_name = '%s'\n"+
" AND constraint_name = '%s'\n"+
" ) THEN\n"+
" ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s;\n"+
" END IF;\n"+
"END;\n$$",
schema.Name, table.Name, constraint.Name,
schema.SQLName(), table.SQLName(), constraint.Name, schema.SQLName(), table.SQLName(), constraint.Name,
strings.Join(constraint.Columns, ", "), strings.Join(constraint.Columns, ", "),
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable), strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
@@ -336,14 +387,25 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
// Type with length/precision - convert to valid PostgreSQL type // Type with length/precision - convert to valid PostgreSQL type
baseType := pgsql.ConvertSQLType(col.Type) baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType typeStr := baseType
// Only add size specifiers for types that support them
if col.Length > 0 && col.Precision == 0 { if col.Length > 0 && col.Precision == 0 {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) if supportsLength(baseType) {
} else if col.Precision > 0 { typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
if col.Scale > 0 { } else if isTextTypeWithoutLength(baseType) {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) // Convert text with length to varchar
} else { typeStr = fmt.Sprintf("varchar(%d)", col.Length)
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
} }
// For types that don't support length (integer, bigint, etc.), ignore the length
} else if col.Precision > 0 {
if supportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
// For types that don't support precision, ignore it
} }
parts = append(parts, typeStr) parts = append(parts, typeStr)
@@ -490,15 +552,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
columnDefs := make([]string, 0, len(columns)) columnDefs := make([]string, 0, len(columns))
for _, col := range columns { for _, col := range columns {
colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type)) // Use generateColumnDefinition to properly handle type, length, precision, and defaults
colDef := " " + w.generateColumnDefinition(col)
// Add default value if present
if col.Default != nil && col.Default != "" {
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
defaultVal := fmt.Sprintf("%v", col.Default)
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
}
columnDefs = append(columnDefs, colDef) columnDefs = append(columnDefs, colDef)
} }
@@ -550,7 +605,32 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
continue continue
} }
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") // Auto-generated primary key names to check for and drop
autoGenPKNames := []string{
fmt.Sprintf("%s_pkey", table.Name),
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
}
fmt.Fprintf(w.writer, "DO $$\nDECLARE\n")
fmt.Fprintf(w.writer, " auto_pk_name text;\nBEGIN\n")
// Check for and drop auto-generated primary keys
fmt.Fprintf(w.writer, " -- Drop auto-generated primary key if it exists\n")
fmt.Fprintf(w.writer, " SELECT constraint_name INTO auto_pk_name\n")
fmt.Fprintf(w.writer, " FROM information_schema.table_constraints\n")
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
fmt.Fprintf(w.writer, " AND constraint_type = 'PRIMARY KEY'\n")
fmt.Fprintf(w.writer, " AND constraint_name IN (%s);\n", formatStringList(autoGenPKNames))
fmt.Fprintf(w.writer, "\n")
fmt.Fprintf(w.writer, " IF auto_pk_name IS NOT NULL THEN\n")
fmt.Fprintf(w.writer, " EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n",
schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "\n")
// Add our named primary key if it doesn't exist
fmt.Fprintf(w.writer, " -- Add named primary key if it doesn't exist\n")
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
@@ -711,13 +791,6 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
onUpdate = strings.ToUpper(fkConstraint.OnUpdate) onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
} }
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
fmt.Fprintf(w.writer, "\n")
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable // Use constraint's referenced schema/table or relationship's ToSchema/ToTable
refSchema := fkConstraint.ReferencedSchema refSchema := fkConstraint.ReferencedSchema
if refSchema == "" { if refSchema == "" {
@@ -728,11 +801,24 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
refTable = rel.ToTable refTable = rel.ToTable
} }
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n", // Use DO block to check if constraint exists before adding
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", fkName)
fmt.Fprintf(w.writer, " ) THEN\n")
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
refSchema, refTable, strings.Join(targetColumns, ", ")) refSchema, refTable, strings.Join(targetColumns, ", "))
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete) fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate) fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n") fmt.Fprintf(w.writer, " DEFERRABLE;\n")
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
} }
} }
@@ -844,6 +930,44 @@ func isTextType(colType string) bool {
return false return false
} }
// supportsLength checks if a PostgreSQL type supports length specification
func supportsLength(colType string) bool {
lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"}
lowerType := strings.ToLower(colType)
for _, t := range lengthTypes {
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
return true
}
}
return false
}
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification
func supportsPrecision(colType string) bool {
precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"}
lowerType := strings.ToLower(colType)
for _, t := range precisionTypes {
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
return true
}
}
return false
}
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
func isTextTypeWithoutLength(colType string) bool {
return strings.EqualFold(colType, "text")
}
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
func formatStringList(items []string) string {
quoted := make([]string, len(items))
for i, item := range items {
quoted[i] = fmt.Sprintf("'%s'", escapeQuote(item))
}
return strings.Join(quoted, ", ")
}
// extractOperatorClass extracts operator class from index comment/note // extractOperatorClass extracts operator class from index comment/note
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc. // Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
func extractOperatorClass(comment string) string { func extractOperatorClass(comment string) string {

View File

@@ -305,3 +305,136 @@ func TestTypeConversion(t *testing.T) {
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output) t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
} }
} }
func TestPrimaryKeyExistenceCheck(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
nameCol := models.InitColumn("name", "products", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
t.Logf("Generated SQL:\n%s", output)
// Verify our naming convention is used
if !strings.Contains(output, "pk_public_products") {
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
}
// Verify it drops auto-generated primary keys
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
}
// Verify it checks for our specific named constraint before adding it
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
}
}
func TestColumnSizeSpecifiers(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("test_sizes", "public")
// Integer with invalid size specifier - should ignore size
integerCol := models.InitColumn("int_col", "test_sizes", "public")
integerCol.Type = "integer"
integerCol.Length = 32
table.Columns["int_col"] = integerCol
// Bigint with invalid size specifier - should ignore size
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
bigintCol.Type = "bigint"
bigintCol.Length = 64
table.Columns["bigint_col"] = bigintCol
// Smallint with invalid size specifier - should ignore size
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
smallintCol.Type = "smallint"
smallintCol.Length = 16
table.Columns["smallint_col"] = smallintCol
// Text with length - should convert to varchar
textCol := models.InitColumn("text_col", "test_sizes", "public")
textCol.Type = "text"
textCol.Length = 100
table.Columns["text_col"] = textCol
// Varchar with length - should keep varchar with length
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
varcharCol.Type = "varchar"
varcharCol.Length = 50
table.Columns["varchar_col"] = varcharCol
// Decimal with precision and scale - should keep them
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
decimalCol.Type = "decimal"
decimalCol.Precision = 19
decimalCol.Scale = 4
table.Columns["decimal_col"] = decimalCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
t.Logf("Generated SQL:\n%s", output)
// Verify invalid size specifiers are NOT present
invalidPatterns := []string{
"integer(32)",
"bigint(64)",
"smallint(16)",
"text(100)",
}
for _, pattern := range invalidPatterns {
if strings.Contains(output, pattern) {
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
}
}
// Verify valid patterns ARE present
validPatterns := []string{
"integer", // without size
"bigint", // without size
"smallint", // without size
"varchar(100)", // text converted to varchar with length
"varchar(50)", // varchar with length
"decimal(19,4)", // decimal with precision and scale
}
for _, pattern := range validPatterns {
if !strings.Contains(output, pattern) {
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
}
}
}