diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 277b4e9..0ef7a8c 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -168,6 +168,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro statements = append(statements, stmts...) } + // Phase 3.5: Add missing columns (for existing tables) + addColStmts, err := w.GenerateAddColumnStatements(schema) + if err != nil { + return nil, fmt.Errorf("failed to generate add column statements: %w", err) + } + statements = append(statements, addColStmts...) + // Phase 4: Primary keys for _, table := range schema.Tables { // First check for explicit PrimaryKeyConstraint @@ -351,6 +358,68 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro return statements, nil } +// GenerateAddColumnStatements generates ALTER TABLE ADD COLUMN statements for existing tables +// This is useful for schema evolution when new columns are added to existing tables +func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, error) { + statements := []string{} + + statements = append(statements, fmt.Sprintf("-- Add missing columns for schema: %s", schema.Name)) + + for _, table := range schema.Tables { + // Sort columns by sequence or name for consistent output + columns := make([]*models.Column, 0, len(table.Columns)) + for _, col := range table.Columns { + columns = append(columns, col) + } + sort.Slice(columns, func(i, j int) bool { + if columns[i].Sequence != columns[j].Sequence { + return columns[i].Sequence < columns[j].Sequence + } + return columns[i].Name < columns[j].Name + }) + + for _, col := range columns { + colDef := w.generateColumnDefinition(col) + + // Generate DO block that checks if column exists before adding + stmt := fmt.Sprintf("DO $$\nBEGIN\n"+ + " IF NOT EXISTS (\n"+ + " SELECT 1 FROM information_schema.columns\n"+ + " WHERE table_schema = '%s'\n"+ + " AND table_name = '%s'\n"+ + " AND column_name = '%s'\n"+ + " ) THEN\n"+ + " ALTER TABLE %s.%s ADD COLUMN %s;\n"+ + " END IF;\n"+ + "END;\n$$", + schema.Name, table.Name, col.Name, + schema.SQLName(), table.SQLName(), colDef) + statements = append(statements, stmt) + } + } + + return statements, nil +} + +// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database +func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) { + statements := []string{} + + statements = append(statements, "-- Add missing columns to existing tables") + statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name)) + statements = append(statements, "-- Generated by RelSpec") + + for _, schema := range db.Schemas { + schemaStatements, err := w.GenerateAddColumnStatements(schema) + if err != nil { + return nil, fmt.Errorf("failed to generate add column statements for schema %s: %w", schema.Name, err) + } + statements = append(statements, schemaStatements...) + } + + return statements, nil +} + // generateCreateTableStatement generates CREATE TABLE statement func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) { statements := []string{} @@ -373,7 +442,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode columnDefs = append(columnDefs, " "+def) } - stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)", + stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)", schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n")) statements = append(statements, stmt) @@ -458,6 +527,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error { return err } + // Phase 3.5: Add missing columns (priority 120) + if err := w.writeAddColumns(schema); err != nil { + return err + } + // Phase 4: Create primary keys (priority 160) if err := w.writePrimaryKeys(schema); err != nil { return err @@ -499,6 +573,44 @@ func (w *Writer) WriteTable(table *models.Table) error { return w.WriteSchema(schema) } +// WriteAddColumnStatements writes ALTER TABLE ADD COLUMN statements for a database +// This is used for schema evolution/migration when new columns are added +func (w *Writer) WriteAddColumnStatements(db *models.Database) error { + var writer io.Writer + var file *os.File + var err error + + // Use existing writer if already set (for testing) + if w.writer != nil { + writer = w.writer + } else if w.options.OutputPath != "" { + // Determine output destination + file, err = os.Create(w.options.OutputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer file.Close() + writer = file + } else { + writer = os.Stdout + } + + w.writer = writer + + // Generate statements + statements, err := w.GenerateAddColumnsForDatabase(db) + if err != nil { + return err + } + + // Write each statement + for _, stmt := range statements { + fmt.Fprintf(w.writer, "%s;\n\n", stmt) + } + + return nil +} + // writeCreateSchema generates CREATE SCHEMA statement func (w *Writer) writeCreateSchema(schema *models.Schema) error { if schema.Name == "public" { @@ -564,6 +676,35 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error { return nil } +// writeAddColumns generates ALTER TABLE ADD COLUMN statements for missing columns +func (w *Writer) writeAddColumns(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Add missing columns for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + // Sort columns by sequence or name for consistent output + columns := getSortedColumns(table.Columns) + + for _, col := range columns { + colDef := w.generateColumnDefinition(col) + + // Generate DO block that checks if column exists before adding + fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") + fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") + fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.columns\n") + fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) + fmt.Fprintf(w.writer, " AND column_name = '%s'\n", col.Name) + fmt.Fprintf(w.writer, " ) THEN\n") + fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD COLUMN %s;\n", + schema.SQLName(), table.SQLName(), colDef) + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "END;\n$$;\n\n") + } + } + + return nil +} + // writePrimaryKeys generates ALTER TABLE statements for primary keys func (w *Writer) writePrimaryKeys(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name) diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index 4344881..5a857e0 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -438,3 +438,130 @@ func TestColumnSizeSpecifiers(t *testing.T) { } } } + +func TestGenerateAddColumnStatements(t *testing.T) { + // Create a test database with tables that have new columns + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create a table with columns + table := models.InitTable("users", "public") + + // Existing column + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + idCol.NotNull = true + idCol.Sequence = 1 + table.Columns["id"] = idCol + + // New column to be added + emailCol := models.InitColumn("email", "users", "public") + emailCol.Type = "varchar" + emailCol.Length = 255 + emailCol.NotNull = true + emailCol.Sequence = 2 + table.Columns["email"] = emailCol + + // New column with default + statusCol := models.InitColumn("status", "users", "public") + statusCol.Type = "text" + statusCol.Default = "active" + statusCol.Sequence = 3 + table.Columns["status"] = statusCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + // Create writer + options := &writers.WriterOptions{} + writer := NewWriter(options) + + // Generate ADD COLUMN statements + statements, err := writer.GenerateAddColumnsForDatabase(db) + if err != nil { + t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err) + } + + // Join all statements to verify content + output := strings.Join(statements, "\n") + t.Logf("Generated ADD COLUMN statements:\n%s", output) + + // Verify expected elements + expectedStrings := []string{ + "ALTER TABLE public.users ADD COLUMN id integer NOT NULL", + "ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL", + "ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'", + "information_schema.columns", + "table_schema = 'public'", + "table_name = 'users'", + "column_name = 'id'", + "column_name = 'email'", + "column_name = 'status'", + } + + for _, expected := range expectedStrings { + if !strings.Contains(output, expected) { + t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output) + } + } + + // Verify DO blocks are present for conditional adds + doBlockCount := strings.Count(output, "DO $$") + if doBlockCount < 3 { + t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount) + } + + // Verify IF NOT EXISTS logic + ifNotExistsCount := strings.Count(output, "IF NOT EXISTS") + if ifNotExistsCount < 3 { + t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount) + } +} + +func TestWriteAddColumnStatements(t *testing.T) { + // Create a test database + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create a table with a new column to be added + table := models.InitTable("products", "public") + + idCol := models.InitColumn("id", "products", "public") + idCol.Type = "integer" + table.Columns["id"] = idCol + + // New column with various properties + descCol := models.InitColumn("description", "products", "public") + descCol.Type = "text" + descCol.NotNull = false + table.Columns["description"] = descCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write ADD COLUMN statements + err := writer.WriteAddColumnStatements(db) + if err != nil { + t.Fatalf("WriteAddColumnStatements failed: %v", err) + } + + output := buf.String() + t.Logf("Generated output:\n%s", output) + + // Verify output contains expected elements + if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") { + t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output) + } + if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") { + t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output) + } + if !strings.Contains(output, "DO $$") { + t.Errorf("Output missing DO block\nFull output:\n%s", output) + } +}