Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e7a15c8e4f | |||
| c36b5ede2b |
@@ -168,6 +168,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
statements = append(statements, stmts...)
|
statements = append(statements, stmts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Phase 3.5: Add missing columns (for existing tables)
|
||||||
|
addColStmts, err := w.GenerateAddColumnStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate add column statements: %w", err)
|
||||||
|
}
|
||||||
|
statements = append(statements, addColStmts...)
|
||||||
|
|
||||||
// Phase 4: Primary keys
|
// Phase 4: Primary keys
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
// First check for explicit PrimaryKeyConstraint
|
// First check for explicit PrimaryKeyConstraint
|
||||||
@@ -179,27 +186,67 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var pkColumns []string
|
||||||
|
var pkName string
|
||||||
|
|
||||||
if pkConstraint != nil {
|
if pkConstraint != nil {
|
||||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
pkColumns = pkConstraint.Columns
|
||||||
schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", "))
|
pkName = pkConstraint.Name
|
||||||
statements = append(statements, stmt)
|
|
||||||
} else {
|
} else {
|
||||||
// No explicit constraint, check for columns with IsPrimaryKey = true
|
// No explicit constraint, check for columns with IsPrimaryKey = true
|
||||||
pkColumns := []string{}
|
pkCols := []string{}
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
pkColumns = append(pkColumns, col.SQLName())
|
pkCols = append(pkCols, col.SQLName())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(pkColumns) > 0 {
|
if len(pkCols) > 0 {
|
||||||
// Sort for consistent output
|
// Sort for consistent output
|
||||||
sort.Strings(pkColumns)
|
sort.Strings(pkCols)
|
||||||
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
pkColumns = pkCols
|
||||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
pkName = fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
||||||
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
|
|
||||||
statements = append(statements, stmt)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(pkColumns) > 0 {
|
||||||
|
// Auto-generated primary key names to check for and drop
|
||||||
|
autoGenPKNames := []string{
|
||||||
|
fmt.Sprintf("%s_pkey", table.Name),
|
||||||
|
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap in DO block to drop auto-generated PK and add our named PK
|
||||||
|
stmt := fmt.Sprintf("DO $$\nDECLARE\n"+
|
||||||
|
" auto_pk_name text;\n"+
|
||||||
|
"BEGIN\n"+
|
||||||
|
" -- Drop auto-generated primary key if it exists\n"+
|
||||||
|
" SELECT constraint_name INTO auto_pk_name\n"+
|
||||||
|
" FROM information_schema.table_constraints\n"+
|
||||||
|
" WHERE table_schema = '%s'\n"+
|
||||||
|
" AND table_name = '%s'\n"+
|
||||||
|
" AND constraint_type = 'PRIMARY KEY'\n"+
|
||||||
|
" AND constraint_name IN (%s);\n"+
|
||||||
|
"\n"+
|
||||||
|
" IF auto_pk_name IS NOT NULL THEN\n"+
|
||||||
|
" EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n"+
|
||||||
|
" END IF;\n"+
|
||||||
|
"\n"+
|
||||||
|
" -- Add named primary key if it doesn't exist\n"+
|
||||||
|
" IF NOT EXISTS (\n"+
|
||||||
|
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||||
|
" WHERE table_schema = '%s'\n"+
|
||||||
|
" AND table_name = '%s'\n"+
|
||||||
|
" AND constraint_name = '%s'\n"+
|
||||||
|
" ) THEN\n"+
|
||||||
|
" ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+
|
||||||
|
" END IF;\n"+
|
||||||
|
"END;\n$$",
|
||||||
|
schema.Name, table.Name, formatStringList(autoGenPKNames),
|
||||||
|
schema.SQLName(), table.SQLName(),
|
||||||
|
schema.Name, table.Name, pkName,
|
||||||
|
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 5: Indexes
|
// Phase 5: Indexes
|
||||||
@@ -270,7 +317,18 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
onUpdate = "NO ACTION"
|
onUpdate = "NO ACTION"
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
|
// Wrap in DO block to check for existing constraint
|
||||||
|
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||||
|
" IF NOT EXISTS (\n"+
|
||||||
|
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||||
|
" WHERE table_schema = '%s'\n"+
|
||||||
|
" AND table_name = '%s'\n"+
|
||||||
|
" AND constraint_name = '%s'\n"+
|
||||||
|
" ) THEN\n"+
|
||||||
|
" ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s;\n"+
|
||||||
|
" END IF;\n"+
|
||||||
|
"END;\n$$",
|
||||||
|
schema.Name, table.Name, constraint.Name,
|
||||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||||
strings.Join(constraint.Columns, ", "),
|
strings.Join(constraint.Columns, ", "),
|
||||||
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
||||||
@@ -300,6 +358,68 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
return statements, nil
|
return statements, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateAddColumnStatements generates ALTER TABLE ADD COLUMN statements for existing tables
|
||||||
|
// This is useful for schema evolution when new columns are added to existing tables
|
||||||
|
func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, error) {
|
||||||
|
statements := []string{}
|
||||||
|
|
||||||
|
statements = append(statements, fmt.Sprintf("-- Add missing columns for schema: %s", schema.Name))
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
// Sort columns by sequence or name for consistent output
|
||||||
|
columns := make([]*models.Column, 0, len(table.Columns))
|
||||||
|
for _, col := range table.Columns {
|
||||||
|
columns = append(columns, col)
|
||||||
|
}
|
||||||
|
sort.Slice(columns, func(i, j int) bool {
|
||||||
|
if columns[i].Sequence != columns[j].Sequence {
|
||||||
|
return columns[i].Sequence < columns[j].Sequence
|
||||||
|
}
|
||||||
|
return columns[i].Name < columns[j].Name
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, col := range columns {
|
||||||
|
colDef := w.generateColumnDefinition(col)
|
||||||
|
|
||||||
|
// Generate DO block that checks if column exists before adding
|
||||||
|
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||||
|
" IF NOT EXISTS (\n"+
|
||||||
|
" SELECT 1 FROM information_schema.columns\n"+
|
||||||
|
" WHERE table_schema = '%s'\n"+
|
||||||
|
" AND table_name = '%s'\n"+
|
||||||
|
" AND column_name = '%s'\n"+
|
||||||
|
" ) THEN\n"+
|
||||||
|
" ALTER TABLE %s.%s ADD COLUMN %s;\n"+
|
||||||
|
" END IF;\n"+
|
||||||
|
"END;\n$$",
|
||||||
|
schema.Name, table.Name, col.Name,
|
||||||
|
schema.SQLName(), table.SQLName(), colDef)
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
||||||
|
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
||||||
|
statements := []string{}
|
||||||
|
|
||||||
|
statements = append(statements, "-- Add missing columns to existing tables")
|
||||||
|
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
||||||
|
statements = append(statements, "-- Generated by RelSpec")
|
||||||
|
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
schemaStatements, err := w.GenerateAddColumnStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate add column statements for schema %s: %w", schema.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, schemaStatements...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements, nil
|
||||||
|
}
|
||||||
|
|
||||||
// generateCreateTableStatement generates CREATE TABLE statement
|
// generateCreateTableStatement generates CREATE TABLE statement
|
||||||
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
||||||
statements := []string{}
|
statements := []string{}
|
||||||
@@ -322,7 +442,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
|||||||
columnDefs = append(columnDefs, " "+def)
|
columnDefs = append(columnDefs, " "+def)
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
|
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)",
|
||||||
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
||||||
statements = append(statements, stmt)
|
statements = append(statements, stmt)
|
||||||
|
|
||||||
@@ -336,14 +456,25 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|||||||
// Type with length/precision - convert to valid PostgreSQL type
|
// Type with length/precision - convert to valid PostgreSQL type
|
||||||
baseType := pgsql.ConvertSQLType(col.Type)
|
baseType := pgsql.ConvertSQLType(col.Type)
|
||||||
typeStr := baseType
|
typeStr := baseType
|
||||||
|
|
||||||
|
// Only add size specifiers for types that support them
|
||||||
if col.Length > 0 && col.Precision == 0 {
|
if col.Length > 0 && col.Precision == 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
if supportsLength(baseType) {
|
||||||
} else if col.Precision > 0 {
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||||
if col.Scale > 0 {
|
} else if isTextTypeWithoutLength(baseType) {
|
||||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
// Convert text with length to varchar
|
||||||
} else {
|
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
|
||||||
}
|
}
|
||||||
|
// For types that don't support length (integer, bigint, etc.), ignore the length
|
||||||
|
} else if col.Precision > 0 {
|
||||||
|
if supportsPrecision(baseType) {
|
||||||
|
if col.Scale > 0 {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||||
|
} else {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// For types that don't support precision, ignore it
|
||||||
}
|
}
|
||||||
parts = append(parts, typeStr)
|
parts = append(parts, typeStr)
|
||||||
|
|
||||||
@@ -396,6 +527,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Phase 3.5: Add missing columns (priority 120)
|
||||||
|
if err := w.writeAddColumns(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 4: Create primary keys (priority 160)
|
// Phase 4: Create primary keys (priority 160)
|
||||||
if err := w.writePrimaryKeys(schema); err != nil {
|
if err := w.writePrimaryKeys(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -437,6 +573,44 @@ func (w *Writer) WriteTable(table *models.Table) error {
|
|||||||
return w.WriteSchema(schema)
|
return w.WriteSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAddColumnStatements writes ALTER TABLE ADD COLUMN statements for a database
|
||||||
|
// This is used for schema evolution/migration when new columns are added
|
||||||
|
func (w *Writer) WriteAddColumnStatements(db *models.Database) error {
|
||||||
|
var writer io.Writer
|
||||||
|
var file *os.File
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Use existing writer if already set (for testing)
|
||||||
|
if w.writer != nil {
|
||||||
|
writer = w.writer
|
||||||
|
} else if w.options.OutputPath != "" {
|
||||||
|
// Determine output destination
|
||||||
|
file, err = os.Create(w.options.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create output file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
writer = file
|
||||||
|
} else {
|
||||||
|
writer = os.Stdout
|
||||||
|
}
|
||||||
|
|
||||||
|
w.writer = writer
|
||||||
|
|
||||||
|
// Generate statements
|
||||||
|
statements, err := w.GenerateAddColumnsForDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write each statement
|
||||||
|
for _, stmt := range statements {
|
||||||
|
fmt.Fprintf(w.writer, "%s;\n\n", stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writeCreateSchema generates CREATE SCHEMA statement
|
// writeCreateSchema generates CREATE SCHEMA statement
|
||||||
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
||||||
if schema.Name == "public" {
|
if schema.Name == "public" {
|
||||||
@@ -490,15 +664,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
|||||||
columnDefs := make([]string, 0, len(columns))
|
columnDefs := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type))
|
// Use generateColumnDefinition to properly handle type, length, precision, and defaults
|
||||||
|
colDef := " " + w.generateColumnDefinition(col)
|
||||||
// Add default value if present
|
|
||||||
if col.Default != nil && col.Default != "" {
|
|
||||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
|
||||||
defaultVal := fmt.Sprintf("%v", col.Default)
|
|
||||||
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
|
|
||||||
}
|
|
||||||
|
|
||||||
columnDefs = append(columnDefs, colDef)
|
columnDefs = append(columnDefs, colDef)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -509,6 +676,35 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeAddColumns generates ALTER TABLE ADD COLUMN statements for missing columns
|
||||||
|
func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
||||||
|
fmt.Fprintf(w.writer, "-- Add missing columns for schema: %s\n", schema.Name)
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
// Sort columns by sequence or name for consistent output
|
||||||
|
columns := getSortedColumns(table.Columns)
|
||||||
|
|
||||||
|
for _, col := range columns {
|
||||||
|
colDef := w.generateColumnDefinition(col)
|
||||||
|
|
||||||
|
// Generate DO block that checks if column exists before adding
|
||||||
|
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||||
|
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||||
|
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.columns\n")
|
||||||
|
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||||
|
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||||
|
fmt.Fprintf(w.writer, " AND column_name = '%s'\n", col.Name)
|
||||||
|
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||||
|
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD COLUMN %s;\n",
|
||||||
|
schema.SQLName(), table.SQLName(), colDef)
|
||||||
|
fmt.Fprintf(w.writer, " END IF;\n")
|
||||||
|
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
||||||
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
||||||
@@ -550,7 +746,32 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
// Auto-generated primary key names to check for and drop
|
||||||
|
autoGenPKNames := []string{
|
||||||
|
fmt.Sprintf("%s_pkey", table.Name),
|
||||||
|
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w.writer, "DO $$\nDECLARE\n")
|
||||||
|
fmt.Fprintf(w.writer, " auto_pk_name text;\nBEGIN\n")
|
||||||
|
|
||||||
|
// Check for and drop auto-generated primary keys
|
||||||
|
fmt.Fprintf(w.writer, " -- Drop auto-generated primary key if it exists\n")
|
||||||
|
fmt.Fprintf(w.writer, " SELECT constraint_name INTO auto_pk_name\n")
|
||||||
|
fmt.Fprintf(w.writer, " FROM information_schema.table_constraints\n")
|
||||||
|
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||||
|
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||||
|
fmt.Fprintf(w.writer, " AND constraint_type = 'PRIMARY KEY'\n")
|
||||||
|
fmt.Fprintf(w.writer, " AND constraint_name IN (%s);\n", formatStringList(autoGenPKNames))
|
||||||
|
fmt.Fprintf(w.writer, "\n")
|
||||||
|
fmt.Fprintf(w.writer, " IF auto_pk_name IS NOT NULL THEN\n")
|
||||||
|
fmt.Fprintf(w.writer, " EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n",
|
||||||
|
schema.SQLName(), table.SQLName())
|
||||||
|
fmt.Fprintf(w.writer, " END IF;\n")
|
||||||
|
fmt.Fprintf(w.writer, "\n")
|
||||||
|
|
||||||
|
// Add our named primary key if it doesn't exist
|
||||||
|
fmt.Fprintf(w.writer, " -- Add named primary key if it doesn't exist\n")
|
||||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||||
@@ -711,13 +932,6 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
|||||||
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
|
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
||||||
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
|
|
||||||
fmt.Fprintf(w.writer, "\n")
|
|
||||||
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
||||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
|
||||||
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
|
||||||
|
|
||||||
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
|
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
|
||||||
refSchema := fkConstraint.ReferencedSchema
|
refSchema := fkConstraint.ReferencedSchema
|
||||||
if refSchema == "" {
|
if refSchema == "" {
|
||||||
@@ -728,11 +942,24 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
|||||||
refTable = rel.ToTable
|
refTable = rel.ToTable
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
// Use DO block to check if constraint exists before adding
|
||||||
|
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||||
|
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||||
|
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||||
|
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||||
|
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||||
|
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", fkName)
|
||||||
|
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||||
|
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
||||||
|
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
||||||
|
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
||||||
|
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
||||||
refSchema, refTable, strings.Join(targetColumns, ", "))
|
refSchema, refTable, strings.Join(targetColumns, ", "))
|
||||||
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
||||||
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
||||||
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
|
fmt.Fprintf(w.writer, " DEFERRABLE;\n")
|
||||||
|
fmt.Fprintf(w.writer, " END IF;\n")
|
||||||
|
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -844,6 +1071,44 @@ func isTextType(colType string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// supportsLength checks if a PostgreSQL type supports length specification
|
||||||
|
func supportsLength(colType string) bool {
|
||||||
|
lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"}
|
||||||
|
lowerType := strings.ToLower(colType)
|
||||||
|
for _, t := range lengthTypes {
|
||||||
|
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification
|
||||||
|
func supportsPrecision(colType string) bool {
|
||||||
|
precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"}
|
||||||
|
lowerType := strings.ToLower(colType)
|
||||||
|
for _, t := range precisionTypes {
|
||||||
|
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
||||||
|
func isTextTypeWithoutLength(colType string) bool {
|
||||||
|
return strings.EqualFold(colType, "text")
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
|
||||||
|
func formatStringList(items []string) string {
|
||||||
|
quoted := make([]string, len(items))
|
||||||
|
for i, item := range items {
|
||||||
|
quoted[i] = fmt.Sprintf("'%s'", escapeQuote(item))
|
||||||
|
}
|
||||||
|
return strings.Join(quoted, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
// extractOperatorClass extracts operator class from index comment/note
|
// extractOperatorClass extracts operator class from index comment/note
|
||||||
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
|
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
|
||||||
func extractOperatorClass(comment string) string {
|
func extractOperatorClass(comment string) string {
|
||||||
|
|||||||
@@ -305,3 +305,263 @@ func TestTypeConversion(t *testing.T) {
|
|||||||
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "products", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "products", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify our naming convention is used
|
||||||
|
if !strings.Contains(output, "pk_public_products") {
|
||||||
|
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it drops auto-generated primary keys
|
||||||
|
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
|
||||||
|
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it checks for our specific named constraint before adding it
|
||||||
|
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
||||||
|
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumnSizeSpecifiers(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("test_sizes", "public")
|
||||||
|
|
||||||
|
// Integer with invalid size specifier - should ignore size
|
||||||
|
integerCol := models.InitColumn("int_col", "test_sizes", "public")
|
||||||
|
integerCol.Type = "integer"
|
||||||
|
integerCol.Length = 32
|
||||||
|
table.Columns["int_col"] = integerCol
|
||||||
|
|
||||||
|
// Bigint with invalid size specifier - should ignore size
|
||||||
|
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
|
||||||
|
bigintCol.Type = "bigint"
|
||||||
|
bigintCol.Length = 64
|
||||||
|
table.Columns["bigint_col"] = bigintCol
|
||||||
|
|
||||||
|
// Smallint with invalid size specifier - should ignore size
|
||||||
|
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
|
||||||
|
smallintCol.Type = "smallint"
|
||||||
|
smallintCol.Length = 16
|
||||||
|
table.Columns["smallint_col"] = smallintCol
|
||||||
|
|
||||||
|
// Text with length - should convert to varchar
|
||||||
|
textCol := models.InitColumn("text_col", "test_sizes", "public")
|
||||||
|
textCol.Type = "text"
|
||||||
|
textCol.Length = 100
|
||||||
|
table.Columns["text_col"] = textCol
|
||||||
|
|
||||||
|
// Varchar with length - should keep varchar with length
|
||||||
|
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
|
||||||
|
varcharCol.Type = "varchar"
|
||||||
|
varcharCol.Length = 50
|
||||||
|
table.Columns["varchar_col"] = varcharCol
|
||||||
|
|
||||||
|
// Decimal with precision and scale - should keep them
|
||||||
|
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
|
||||||
|
decimalCol.Type = "decimal"
|
||||||
|
decimalCol.Precision = 19
|
||||||
|
decimalCol.Scale = 4
|
||||||
|
table.Columns["decimal_col"] = decimalCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify invalid size specifiers are NOT present
|
||||||
|
invalidPatterns := []string{
|
||||||
|
"integer(32)",
|
||||||
|
"bigint(64)",
|
||||||
|
"smallint(16)",
|
||||||
|
"text(100)",
|
||||||
|
}
|
||||||
|
for _, pattern := range invalidPatterns {
|
||||||
|
if strings.Contains(output, pattern) {
|
||||||
|
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify valid patterns ARE present
|
||||||
|
validPatterns := []string{
|
||||||
|
"integer", // without size
|
||||||
|
"bigint", // without size
|
||||||
|
"smallint", // without size
|
||||||
|
"varchar(100)", // text converted to varchar with length
|
||||||
|
"varchar(50)", // varchar with length
|
||||||
|
"decimal(19,4)", // decimal with precision and scale
|
||||||
|
}
|
||||||
|
for _, pattern := range validPatterns {
|
||||||
|
if !strings.Contains(output, pattern) {
|
||||||
|
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAddColumnStatements(t *testing.T) {
|
||||||
|
// Create a test database with tables that have new columns
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a table with columns
|
||||||
|
table := models.InitTable("users", "public")
|
||||||
|
|
||||||
|
// Existing column
|
||||||
|
idCol := models.InitColumn("id", "users", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.NotNull = true
|
||||||
|
idCol.Sequence = 1
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
// New column to be added
|
||||||
|
emailCol := models.InitColumn("email", "users", "public")
|
||||||
|
emailCol.Type = "varchar"
|
||||||
|
emailCol.Length = 255
|
||||||
|
emailCol.NotNull = true
|
||||||
|
emailCol.Sequence = 2
|
||||||
|
table.Columns["email"] = emailCol
|
||||||
|
|
||||||
|
// New column with default
|
||||||
|
statusCol := models.InitColumn("status", "users", "public")
|
||||||
|
statusCol.Type = "text"
|
||||||
|
statusCol.Default = "active"
|
||||||
|
statusCol.Sequence = 3
|
||||||
|
table.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
|
||||||
|
// Generate ADD COLUMN statements
|
||||||
|
statements, err := writer.GenerateAddColumnsForDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all statements to verify content
|
||||||
|
output := strings.Join(statements, "\n")
|
||||||
|
t.Logf("Generated ADD COLUMN statements:\n%s", output)
|
||||||
|
|
||||||
|
// Verify expected elements
|
||||||
|
expectedStrings := []string{
|
||||||
|
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
|
||||||
|
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
|
||||||
|
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
|
||||||
|
"information_schema.columns",
|
||||||
|
"table_schema = 'public'",
|
||||||
|
"table_name = 'users'",
|
||||||
|
"column_name = 'id'",
|
||||||
|
"column_name = 'email'",
|
||||||
|
"column_name = 'status'",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectedStrings {
|
||||||
|
if !strings.Contains(output, expected) {
|
||||||
|
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify DO blocks are present for conditional adds
|
||||||
|
doBlockCount := strings.Count(output, "DO $$")
|
||||||
|
if doBlockCount < 3 {
|
||||||
|
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify IF NOT EXISTS logic
|
||||||
|
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
|
||||||
|
if ifNotExistsCount < 3 {
|
||||||
|
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAddColumnStatements(t *testing.T) {
|
||||||
|
// Create a test database
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a table with a new column to be added
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "products", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
// New column with various properties
|
||||||
|
descCol := models.InitColumn("description", "products", "public")
|
||||||
|
descCol.Type = "text"
|
||||||
|
descCol.NotNull = false
|
||||||
|
table.Columns["description"] = descCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write ADD COLUMN statements
|
||||||
|
err := writer.WriteAddColumnStatements(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteAddColumnStatements failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated output:\n%s", output)
|
||||||
|
|
||||||
|
// Verify output contains expected elements
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
|
||||||
|
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
|
||||||
|
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "DO $$") {
|
||||||
|
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user