feat(writer): 🎉 Implement add column statements for schema evolution
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m14s
CI / Lint (push) Successful in -26m30s
CI / Build (push) Successful in -26m41s
Release / Build and Release (push) Successful in -26m29s
Integration Tests / Integration Tests (push) Successful in -26m13s

* Add functionality to generate ALTER TABLE ADD COLUMN statements for existing tables.
* Introduce tests for generating and writing add column statements.
* Enhance schema evolution capabilities when new columns are added.
This commit is contained in:
2026-01-31 19:12:00 +02:00
parent c36b5ede2b
commit e7a15c8e4f
2 changed files with 269 additions and 1 deletions

View File

@@ -168,6 +168,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
statements = append(statements, stmts...) statements = append(statements, stmts...)
} }
// Phase 3.5: Add missing columns (for existing tables)
addColStmts, err := w.GenerateAddColumnStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate add column statements: %w", err)
}
statements = append(statements, addColStmts...)
// Phase 4: Primary keys // Phase 4: Primary keys
for _, table := range schema.Tables { for _, table := range schema.Tables {
// First check for explicit PrimaryKeyConstraint // First check for explicit PrimaryKeyConstraint
@@ -351,6 +358,68 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
return statements, nil return statements, nil
} }
// GenerateAddColumnStatements generates ALTER TABLE ADD COLUMN statements for existing tables
// This is useful for schema evolution when new columns are added to existing tables
func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, error) {
statements := []string{}
statements = append(statements, fmt.Sprintf("-- Add missing columns for schema: %s", schema.Name))
for _, table := range schema.Tables {
// Sort columns by sequence or name for consistent output
columns := make([]*models.Column, 0, len(table.Columns))
for _, col := range table.Columns {
columns = append(columns, col)
}
sort.Slice(columns, func(i, j int) bool {
if columns[i].Sequence != columns[j].Sequence {
return columns[i].Sequence < columns[j].Sequence
}
return columns[i].Name < columns[j].Name
})
for _, col := range columns {
colDef := w.generateColumnDefinition(col)
// Generate DO block that checks if column exists before adding
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
" IF NOT EXISTS (\n"+
" SELECT 1 FROM information_schema.columns\n"+
" WHERE table_schema = '%s'\n"+
" AND table_name = '%s'\n"+
" AND column_name = '%s'\n"+
" ) THEN\n"+
" ALTER TABLE %s.%s ADD COLUMN %s;\n"+
" END IF;\n"+
"END;\n$$",
schema.Name, table.Name, col.Name,
schema.SQLName(), table.SQLName(), colDef)
statements = append(statements, stmt)
}
}
return statements, nil
}
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
statements := []string{}
statements = append(statements, "-- Add missing columns to existing tables")
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
statements = append(statements, "-- Generated by RelSpec")
for _, schema := range db.Schemas {
schemaStatements, err := w.GenerateAddColumnStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate add column statements for schema %s: %w", schema.Name, err)
}
statements = append(statements, schemaStatements...)
}
return statements, nil
}
// generateCreateTableStatement generates CREATE TABLE statement // generateCreateTableStatement generates CREATE TABLE statement
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) { func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
statements := []string{} statements := []string{}
@@ -373,7 +442,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
columnDefs = append(columnDefs, " "+def) columnDefs = append(columnDefs, " "+def)
} }
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)", stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)",
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n")) schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
statements = append(statements, stmt) statements = append(statements, stmt)
@@ -458,6 +527,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err return err
} }
// Phase 3.5: Add missing columns (priority 120)
if err := w.writeAddColumns(schema); err != nil {
return err
}
// Phase 4: Create primary keys (priority 160) // Phase 4: Create primary keys (priority 160)
if err := w.writePrimaryKeys(schema); err != nil { if err := w.writePrimaryKeys(schema); err != nil {
return err return err
@@ -499,6 +573,44 @@ func (w *Writer) WriteTable(table *models.Table) error {
return w.WriteSchema(schema) return w.WriteSchema(schema)
} }
// WriteAddColumnStatements writes ALTER TABLE ADD COLUMN statements for a database
// This is used for schema evolution/migration when new columns are added
func (w *Writer) WriteAddColumnStatements(db *models.Database) error {
var writer io.Writer
var file *os.File
var err error
// Use existing writer if already set (for testing)
if w.writer != nil {
writer = w.writer
} else if w.options.OutputPath != "" {
// Determine output destination
file, err = os.Create(w.options.OutputPath)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer file.Close()
writer = file
} else {
writer = os.Stdout
}
w.writer = writer
// Generate statements
statements, err := w.GenerateAddColumnsForDatabase(db)
if err != nil {
return err
}
// Write each statement
for _, stmt := range statements {
fmt.Fprintf(w.writer, "%s;\n\n", stmt)
}
return nil
}
// writeCreateSchema generates CREATE SCHEMA statement // writeCreateSchema generates CREATE SCHEMA statement
func (w *Writer) writeCreateSchema(schema *models.Schema) error { func (w *Writer) writeCreateSchema(schema *models.Schema) error {
if schema.Name == "public" { if schema.Name == "public" {
@@ -564,6 +676,35 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
return nil return nil
} }
// writeAddColumns generates ALTER TABLE ADD COLUMN statements for missing columns
func (w *Writer) writeAddColumns(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Add missing columns for schema: %s\n", schema.Name)
for _, table := range schema.Tables {
// Sort columns by sequence or name for consistent output
columns := getSortedColumns(table.Columns)
for _, col := range columns {
colDef := w.generateColumnDefinition(col)
// Generate DO block that checks if column exists before adding
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.columns\n")
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
fmt.Fprintf(w.writer, " AND column_name = '%s'\n", col.Name)
fmt.Fprintf(w.writer, " ) THEN\n")
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD COLUMN %s;\n",
schema.SQLName(), table.SQLName(), colDef)
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
}
}
return nil
}
// writePrimaryKeys generates ALTER TABLE statements for primary keys // writePrimaryKeys generates ALTER TABLE statements for primary keys
func (w *Writer) writePrimaryKeys(schema *models.Schema) error { func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name) fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)

View File

@@ -438,3 +438,130 @@ func TestColumnSizeSpecifiers(t *testing.T) {
} }
} }
} }
func TestGenerateAddColumnStatements(t *testing.T) {
// Create a test database with tables that have new columns
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a table with columns
table := models.InitTable("users", "public")
// Existing column
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
idCol.Sequence = 1
table.Columns["id"] = idCol
// New column to be added
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar"
emailCol.Length = 255
emailCol.NotNull = true
emailCol.Sequence = 2
table.Columns["email"] = emailCol
// New column with default
statusCol := models.InitColumn("status", "users", "public")
statusCol.Type = "text"
statusCol.Default = "active"
statusCol.Sequence = 3
table.Columns["status"] = statusCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer
options := &writers.WriterOptions{}
writer := NewWriter(options)
// Generate ADD COLUMN statements
statements, err := writer.GenerateAddColumnsForDatabase(db)
if err != nil {
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
}
// Join all statements to verify content
output := strings.Join(statements, "\n")
t.Logf("Generated ADD COLUMN statements:\n%s", output)
// Verify expected elements
expectedStrings := []string{
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
"information_schema.columns",
"table_schema = 'public'",
"table_name = 'users'",
"column_name = 'id'",
"column_name = 'email'",
"column_name = 'status'",
}
for _, expected := range expectedStrings {
if !strings.Contains(output, expected) {
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
}
}
// Verify DO blocks are present for conditional adds
doBlockCount := strings.Count(output, "DO $$")
if doBlockCount < 3 {
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
}
// Verify IF NOT EXISTS logic
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
if ifNotExistsCount < 3 {
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
}
}
func TestWriteAddColumnStatements(t *testing.T) {
// Create a test database
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a table with a new column to be added
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "integer"
table.Columns["id"] = idCol
// New column with various properties
descCol := models.InitColumn("description", "products", "public")
descCol.Type = "text"
descCol.NotNull = false
table.Columns["description"] = descCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write ADD COLUMN statements
err := writer.WriteAddColumnStatements(db)
if err != nil {
t.Fatalf("WriteAddColumnStatements failed: %v", err)
}
output := buf.String()
t.Logf("Generated output:\n%s", output)
// Verify output contains expected elements
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
}
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
}
if !strings.Contains(output, "DO $$") {
t.Errorf("Output missing DO block\nFull output:\n%s", output)
}
}