From 40bc0be1cbeb6494356dd3695f6f03330813ef11 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 17 Dec 2025 11:05:42 +0200 Subject: [PATCH] SQL Writer --- pkg/writers/pgsql/MIGRATION.md | 258 ++++++++ pkg/writers/pgsql/migration_writer.go | 668 +++++++++++++++++++++ pkg/writers/pgsql/migration_writer_test.go | 287 +++++++++ pkg/writers/pgsql/writer.go | 496 +++++++++++++++ pkg/writers/pgsql/writer_test.go | 243 ++++++++ 5 files changed, 1952 insertions(+) create mode 100644 pkg/writers/pgsql/MIGRATION.md create mode 100644 pkg/writers/pgsql/migration_writer.go create mode 100644 pkg/writers/pgsql/migration_writer_test.go create mode 100644 pkg/writers/pgsql/writer.go create mode 100644 pkg/writers/pgsql/writer_test.go diff --git a/pkg/writers/pgsql/MIGRATION.md b/pkg/writers/pgsql/MIGRATION.md new file mode 100644 index 0000000..ce9d179 --- /dev/null +++ b/pkg/writers/pgsql/MIGRATION.md @@ -0,0 +1,258 @@ +# PostgreSQL Migration Writer + +## Overview + +The PostgreSQL Migration Writer implements database schema inspection and differential migration generation, following the same approach as the `pgsql_meta_upgrade` migration system. It compares a desired model (target schema) against the current database state and generates the necessary SQL migration scripts. + +## Migration Phases + +The migration writer follows a phased approach with specific priorities to ensure proper execution order: + +### Phase 1: Drops (Priority 11-50) +- Drop changed constraints (Priority 11) +- Drop changed indexes (Priority 20) +- Drop changed foreign keys (Priority 50) + +### Phase 2: Renames (Priority 60-90) +- Rename tables (Priority 60) +- Rename columns (Priority 90) +- *Note: Currently requires manual handling or metadata for rename detection* + +### Phase 3: Tables & Columns (Priority 100-145) +- Create new tables (Priority 100) +- Add new columns (Priority 120) +- Alter column types (Priority 120) +- Alter column defaults (Priority 145) + +### Phase 4: Indexes (Priority 160-180) +- Create primary keys (Priority 160) +- Create indexes (Priority 180) + +### Phase 5: Foreign Keys (Priority 195) +- Create foreign key constraints + +### Phase 6: Comments (Priority 200+) +- Add table and column comments + +## Usage + +### 1. Inspect Current Database + +```go +import ( + "git.warky.dev/wdevs/relspecgo/pkg/readers" + "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql" +) + +// Create reader with connection string +options := &readers.ReaderOptions{ + ConnectionString: "host=localhost port=5432 dbname=mydb user=postgres password=secret", +} + +reader := pgsql.NewReader(options) + +// Read current database state +currentDB, err := reader.ReadDatabase() +if err != nil { + log.Fatal(err) +} +``` + +### 2. Define Desired Model + +```go +import "git.warky.dev/wdevs/relspecgo/pkg/models" + +// Create desired model (could be loaded from DBML, JSON, etc.) +modelDB := models.InitDatabase("mydb") +schema := models.InitSchema("public") + +// Define table +table := models.InitTable("users", "public") +table.Description = "User accounts" + +// Add columns +idCol := models.InitColumn("id", "users", "public") +idCol.Type = "integer" +idCol.NotNull = true +idCol.IsPrimaryKey = true +table.Columns["id"] = idCol + +nameCol := models.InitColumn("name", "users", "public") +nameCol.Type = "text" +nameCol.NotNull = true +table.Columns["name"] = nameCol + +emailCol := models.InitColumn("email", "users", "public") +emailCol.Type = "text" +table.Columns["email"] = emailCol + +// Add primary key constraint +pkConstraint := &models.Constraint{ + Name: "pk_users", + Type: models.PrimaryKeyConstraint, + Columns: []string{"id"}, +} +table.Constraints["pk_users"] = pkConstraint + +// Add unique index +emailIndex := &models.Index{ + Name: "uk_users_email", + Unique: true, + Columns: []string{"email"}, +} +table.Indexes["uk_users_email"] = emailIndex + +schema.Tables = append(schema.Tables, table) +modelDB.Schemas = append(modelDB.Schemas, schema) +``` + +### 3. Generate Migration + +```go +import ( + "git.warky.dev/wdevs/relspecgo/pkg/writers" + "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql" +) + +// Create migration writer +writerOptions := &writers.WriterOptions{ + OutputPath: "migration_001.sql", +} + +migrationWriter := pgsql.NewMigrationWriter(writerOptions) + +// Generate migration comparing model vs current +err = migrationWriter.WriteMigration(modelDB, currentDB) +if err != nil { + log.Fatal(err) +} +``` + +## Example Migration Output + +```sql +-- PostgreSQL Migration Script +-- Generated by RelSpec +-- Source: mydb -> mydb + +-- Priority: 11 | Type: drop constraint | Object: public.users.old_constraint +ALTER TABLE public.users DROP CONSTRAINT IF EXISTS old_constraint; + +-- Priority: 100 | Type: create table | Object: public.orders +CREATE TABLE IF NOT EXISTS public.orders ( + id integer NOT NULL, + user_id integer, + total numeric(10,2) DEFAULT 0.00, + created_at timestamp DEFAULT CURRENT_TIMESTAMP +); + +-- Priority: 120 | Type: create column | Object: public.users.phone +ALTER TABLE public.users + ADD COLUMN IF NOT EXISTS phone text; + +-- Priority: 120 | Type: alter column type | Object: public.users.age +ALTER TABLE public.users + ALTER COLUMN age TYPE integer; + +-- Priority: 160 | Type: create primary key | Object: public.orders.pk_orders +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.table_constraints + WHERE table_schema = 'public' + AND table_name = 'orders' + AND constraint_name = 'pk_orders' + ) THEN + ALTER TABLE public.orders + ADD CONSTRAINT pk_orders PRIMARY KEY (id); + END IF; +END; +$$; + +-- Priority: 180 | Type: create index | Object: public.users.idx_users_email +CREATE INDEX IF NOT EXISTS idx_users_email + ON public.users USING btree (email); + +-- Priority: 195 | Type: create foreign key | Object: public.orders.fk_orders_users +ALTER TABLE public.orders + DROP CONSTRAINT IF EXISTS fk_orders_users; + +ALTER TABLE public.orders + ADD CONSTRAINT fk_orders_users + FOREIGN KEY (user_id) + REFERENCES public.users (id) + ON DELETE CASCADE + ON UPDATE CASCADE + DEFERRABLE; + +-- Priority: 200 | Type: comment on table | Object: public.users +COMMENT ON TABLE public.users IS 'User accounts'; + +-- Priority: 200 | Type: comment on column | Object: public.users.email +COMMENT ON COLUMN public.users.email IS 'User email address'; +``` + +## Migration Script Structure + +Each migration script includes: + +- **ObjectName**: Fully qualified name of the object being modified +- **ObjectType**: Type of operation (create table, alter column, etc.) +- **Schema**: Schema name +- **Priority**: Execution order priority (lower runs first) +- **Sequence**: Sub-ordering within same priority +- **Body**: The actual SQL statement + +## Comparison Logic + +The migration writer compares objects using: + +### Tables +- Existence check by name (case-insensitive) +- New tables generate CREATE TABLE statements + +### Columns +- Existence check within tables +- Type changes generate ALTER COLUMN TYPE +- Default value changes generate SET/DROP DEFAULT +- New columns generate ADD COLUMN + +### Constraints +- Compared by type, columns, and referenced objects +- Changed constraints are dropped and recreated + +### Indexes +- Compared by uniqueness and column list +- Changed indexes are dropped and recreated + +### Foreign Keys +- Compared by columns, referenced table/columns, and actions +- Changed foreign keys are dropped and recreated + +## Best Practices + +1. **Always Review Generated Migrations**: Manually review SQL before execution +2. **Test on Non-Production First**: Apply migrations to development/staging environments first +3. **Backup Before Migration**: Create database backup before running migrations +4. **Use Transactions**: Wrap migrations in transactions when possible +5. **Handle Renames Carefully**: Column/table renames may appear as DROP + CREATE without metadata +6. **Consider Data Migration**: Generated SQL handles structure only; data migration may be needed + +## Limitations + +1. **Rename Detection**: Automatic rename detection not implemented; requires GUID or metadata matching +2. **Data Type Conversions**: Some type changes may require custom USING clauses +3. **Complex Constraints**: CHECK constraints with complex expressions may need manual handling +4. **Sequence Values**: Current sequence values not automatically synced +5. **Permissions**: Schema and object permissions not included in migrations + +## Integration with Migration System + +This implementation follows the same logic as the SQL migration system in `examples/pgsql_meta_upgrade`: + +- `migration_inspect.sql` → Reader (pkg/readers/pgsql) +- `migration_build.sql` → MigrationWriter (pkg/writers/pgsql) +- `migration_run.sql` → External execution (psql, application code) + +The phases, priorities, and script generation logic match the original migration system to ensure compatibility and consistency. diff --git a/pkg/writers/pgsql/migration_writer.go b/pkg/writers/pgsql/migration_writer.go new file mode 100644 index 0000000..7c54e49 --- /dev/null +++ b/pkg/writers/pgsql/migration_writer.go @@ -0,0 +1,668 @@ +package pgsql + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +// MigrationWriter generates differential migration SQL scripts +type MigrationWriter struct { + options *writers.WriterOptions + writer io.Writer +} + +// MigrationScript represents a single migration script with priority and sequence +type MigrationScript struct { + ObjectName string + ObjectType string + Schema string + Priority int + Sequence int + Body string +} + +// NewMigrationWriter creates a new migration writer +func NewMigrationWriter(options *writers.WriterOptions) *MigrationWriter { + return &MigrationWriter{ + options: options, + } +} + +// WriteMigration generates migration scripts by comparing model (desired) vs current (actual) database +func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error { + var writer io.Writer + var file *os.File + var err error + + // Use existing writer if already set (for testing) + if w.writer != nil { + writer = w.writer + } else if w.options.OutputPath != "" { + file, err = os.Create(w.options.OutputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer file.Close() + writer = file + } else { + writer = os.Stdout + } + + w.writer = writer + + // Generate all migration scripts + scripts := make([]MigrationScript, 0) + + // Process each schema in the model + for _, modelSchema := range model.Schemas { + // Find corresponding schema in current database + var currentSchema *models.Schema + for _, cs := range current.Schemas { + if strings.EqualFold(cs.Name, modelSchema.Name) { + currentSchema = cs + break + } + } + + // Generate schema-level scripts + schemaScripts := w.generateSchemaScripts(modelSchema, currentSchema) + scripts = append(scripts, schemaScripts...) + } + + // Sort scripts by priority and sequence + sort.Slice(scripts, func(i, j int) bool { + if scripts[i].Priority != scripts[j].Priority { + return scripts[i].Priority < scripts[j].Priority + } + return scripts[i].Sequence < scripts[j].Sequence + }) + + // Write header + fmt.Fprintf(w.writer, "-- PostgreSQL Migration Script\n") + fmt.Fprintf(w.writer, "-- Generated by RelSpec\n") + fmt.Fprintf(w.writer, "-- Source: %s -> %s\n\n", current.Name, model.Name) + + // Write scripts + for _, script := range scripts { + fmt.Fprintf(w.writer, "-- Priority: %d | Type: %s | Object: %s\n", + script.Priority, script.ObjectType, script.ObjectName) + fmt.Fprintf(w.writer, "%s\n\n", script.Body) + } + + return nil +} + +// generateSchemaScripts generates migration scripts for a schema +func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Phase 1: Drop constraints and indexes that changed (Priority 11-50) + if current != nil { + scripts = append(scripts, w.generateDropScripts(model, current)...) + } + + // Phase 2: Rename tables and columns (Priority 60-90) + if current != nil { + scripts = append(scripts, w.generateRenameScripts(model, current)...) + } + + // Phase 3: Create/Alter tables and columns (Priority 100-145) + scripts = append(scripts, w.generateTableScripts(model, current)...) + + // Phase 4: Create indexes (Priority 160-180) + scripts = append(scripts, w.generateIndexScripts(model, current)...) + + // Phase 5: Create foreign keys (Priority 195) + scripts = append(scripts, w.generateForeignKeyScripts(model, current)...) + + // Phase 6: Add comments (Priority 200+) + scripts = append(scripts, w.generateCommentScripts(model, current)...) + + return scripts +} + +// generateDropScripts generates DROP scripts for removed/changed objects +func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Build map of model tables for quick lookup + modelTables := make(map[string]*models.Table) + for _, table := range model.Tables { + modelTables[strings.ToLower(table.Name)] = table + } + + // Find constraints to drop + for _, currentTable := range current.Tables { + modelTable, existsInModel := modelTables[strings.ToLower(currentTable.Name)] + + if !existsInModel { + // Table will be dropped, skip individual constraint drops + continue + } + + // Check each constraint in current database + for constraintName, currentConstraint := range currentTable.Constraints { + // Check if constraint exists in model + modelConstraint, existsInModel := modelTable.Constraints[constraintName] + + shouldDrop := false + + if !existsInModel { + shouldDrop = true + } else if !constraintsEqual(modelConstraint, currentConstraint) { + // Constraint changed, drop and recreate + shouldDrop = true + } + + if shouldDrop { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", current.Name, currentTable.Name, constraintName), + ObjectType: "drop constraint", + Schema: current.Name, + Priority: 11, + Sequence: len(scripts), + Body: fmt.Sprintf( + "ALTER TABLE %s.%s DROP CONSTRAINT IF EXISTS %s;", + current.Name, currentTable.Name, constraintName, + ), + } + scripts = append(scripts, script) + } + } + + // Check indexes + for indexName, currentIndex := range currentTable.Indexes { + modelIndex, existsInModel := modelTable.Indexes[indexName] + + shouldDrop := false + + if !existsInModel { + shouldDrop = true + } else if !indexesEqual(modelIndex, currentIndex) { + shouldDrop = true + } + + if shouldDrop { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", current.Name, currentTable.Name, indexName), + ObjectType: "drop index", + Schema: current.Name, + Priority: 20, + Sequence: len(scripts), + Body: fmt.Sprintf( + "DROP INDEX IF EXISTS %s.%s CASCADE;", + current.Name, indexName, + ), + } + scripts = append(scripts, script) + } + } + } + + return scripts +} + +// generateRenameScripts generates RENAME scripts for renamed objects +func (w *MigrationWriter) generateRenameScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // For now, we don't attempt to detect renames automatically + // This would require GUID matching or other heuristics + // Users would need to handle renames manually or through metadata + + // Suppress unused parameter warnings + _ = model + _ = current + + return scripts +} + +// generateTableScripts generates CREATE/ALTER TABLE scripts +func (w *MigrationWriter) generateTableScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Build map of current tables + currentTables := make(map[string]*models.Table) + if current != nil { + for _, table := range current.Tables { + currentTables[strings.ToLower(table.Name)] = table + } + } + + // Process each model table + for _, modelTable := range model.Tables { + currentTable, exists := currentTables[strings.ToLower(modelTable.Name)] + + if !exists { + // Table doesn't exist, create it + script := w.generateCreateTableScript(model, modelTable) + scripts = append(scripts, script) + } else { + // Table exists, check for column changes + alterScripts := w.generateAlterTableScripts(model, modelTable, currentTable) + scripts = append(scripts, alterScripts...) + } + } + + return scripts +} + +// generateCreateTableScript generates a CREATE TABLE script +func (w *MigrationWriter) generateCreateTableScript(schema *models.Schema, table *models.Table) MigrationScript { + var body strings.Builder + + body.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n", schema.Name, table.Name)) + + // Get sorted columns + columns := getSortedColumns(table.Columns) + columnDefs := make([]string, 0, len(columns)) + + for _, col := range columns { + colDef := fmt.Sprintf(" %s %s", col.Name, col.Type) + + // Add default value if present + if col.Default != nil { + colDef += fmt.Sprintf(" DEFAULT %v", col.Default) + } + + // Add NOT NULL if needed + if col.NotNull { + colDef += " NOT NULL" + } + + columnDefs = append(columnDefs, colDef) + } + + body.WriteString(strings.Join(columnDefs, ",\n")) + body.WriteString("\n);") + + return MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s", schema.Name, table.Name), + ObjectType: "create table", + Schema: schema.Name, + Priority: 100, + Sequence: 0, + Body: body.String(), + } +} + +// generateAlterTableScripts generates ALTER TABLE scripts for column changes +func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, modelTable *models.Table, currentTable *models.Table) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Build map of current columns + currentColumns := make(map[string]*models.Column) + for name, col := range currentTable.Columns { + currentColumns[strings.ToLower(name)] = col + } + + // Check each model column + for _, modelCol := range modelTable.Columns { + currentCol, exists := currentColumns[strings.ToLower(modelCol.Name)] + + if !exists { + // Column doesn't exist, add it + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", schema.Name, modelTable.Name, modelCol.Name), + ObjectType: "create column", + Schema: schema.Name, + Priority: 120, + Sequence: len(scripts), + Body: fmt.Sprintf( + "ALTER TABLE %s.%s\n ADD COLUMN IF NOT EXISTS %s %s%s%s;", + schema.Name, modelTable.Name, modelCol.Name, modelCol.Type, + func() string { + if modelCol.Default != nil { + return fmt.Sprintf(" DEFAULT %v", modelCol.Default) + } + return "" + }(), + func() string { + if modelCol.NotNull { + return " NOT NULL" + } + return "" + }(), + ), + } + scripts = append(scripts, script) + } else if !columnsEqual(modelCol, currentCol) { + // Column exists but type or properties changed + if modelCol.Type != currentCol.Type { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", schema.Name, modelTable.Name, modelCol.Name), + ObjectType: "alter column type", + Schema: schema.Name, + Priority: 120, + Sequence: len(scripts), + Body: fmt.Sprintf( + "ALTER TABLE %s.%s\n ALTER COLUMN %s TYPE %s;", + schema.Name, modelTable.Name, modelCol.Name, modelCol.Type, + ), + } + scripts = append(scripts, script) + } + + // Check default value changes + if fmt.Sprintf("%v", modelCol.Default) != fmt.Sprintf("%v", currentCol.Default) { + if modelCol.Default != nil { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", schema.Name, modelTable.Name, modelCol.Name), + ObjectType: "alter column default", + Schema: schema.Name, + Priority: 145, + Sequence: len(scripts), + Body: fmt.Sprintf( + "ALTER TABLE %s.%s\n ALTER COLUMN %s SET DEFAULT %v;", + schema.Name, modelTable.Name, modelCol.Name, modelCol.Default, + ), + } + scripts = append(scripts, script) + } else { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", schema.Name, modelTable.Name, modelCol.Name), + ObjectType: "alter column default", + Schema: schema.Name, + Priority: 145, + Sequence: len(scripts), + Body: fmt.Sprintf( + "ALTER TABLE %s.%s\n ALTER COLUMN %s DROP DEFAULT;", + schema.Name, modelTable.Name, modelCol.Name, + ), + } + scripts = append(scripts, script) + } + } + } + } + + return scripts +} + +// generateIndexScripts generates CREATE INDEX scripts +func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Build map of current tables + currentTables := make(map[string]*models.Table) + if current != nil { + for _, table := range current.Tables { + currentTables[strings.ToLower(table.Name)] = table + } + } + + // Process each model table + for _, modelTable := range model.Tables { + currentTable := currentTables[strings.ToLower(modelTable.Name)] + + // Process each index in model + for indexName, modelIndex := range modelTable.Indexes { + shouldCreate := true + + // Check if index exists in current + if currentTable != nil { + if currentIndex, exists := currentTable.Indexes[indexName]; exists { + if indexesEqual(modelIndex, currentIndex) { + shouldCreate = false + } + } + } + + if shouldCreate { + unique := "" + if modelIndex.Unique { + unique = "UNIQUE " + } + + indexType := "btree" + if modelIndex.Type != "" { + indexType = modelIndex.Type + } + + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, indexName), + ObjectType: "create index", + Schema: model.Name, + Priority: 180, + Sequence: len(scripts), + Body: fmt.Sprintf( + "CREATE %sINDEX IF NOT EXISTS %s\n ON %s.%s USING %s (%s);", + unique, indexName, model.Name, modelTable.Name, indexType, + strings.Join(modelIndex.Columns, ", "), + ), + } + scripts = append(scripts, script) + } + } + + // Add primary key constraint if it exists + for constraintName, constraint := range modelTable.Constraints { + if constraint.Type == models.PrimaryKeyConstraint { + shouldCreate := true + + if currentTable != nil { + if currentConstraint, exists := currentTable.Constraints[constraintName]; exists { + if constraintsEqual(constraint, currentConstraint) { + shouldCreate = false + } + } + } + + if shouldCreate { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName), + ObjectType: "create primary key", + Schema: model.Name, + Priority: 160, + Sequence: len(scripts), + Body: fmt.Sprintf( + "DO $$\nBEGIN\n IF NOT EXISTS (\n"+ + " SELECT 1 FROM information_schema.table_constraints\n"+ + " WHERE table_schema = '%s'\n"+ + " AND table_name = '%s'\n"+ + " AND constraint_name = '%s'\n"+ + " ) THEN\n"+ + " ALTER TABLE %s.%s\n"+ + " ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+ + " END IF;\n"+ + "END;\n$$;", + model.Name, modelTable.Name, constraintName, + model.Name, modelTable.Name, constraintName, + strings.Join(constraint.Columns, ", "), + ), + } + scripts = append(scripts, script) + } + } + } + } + + return scripts +} + +// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts +func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Build map of current tables + currentTables := make(map[string]*models.Table) + if current != nil { + for _, table := range current.Tables { + currentTables[strings.ToLower(table.Name)] = table + } + } + + // Process each model table + for _, modelTable := range model.Tables { + currentTable := currentTables[strings.ToLower(modelTable.Name)] + + // Process each constraint + for constraintName, constraint := range modelTable.Constraints { + if constraint.Type != models.ForeignKeyConstraint { + continue + } + + shouldCreate := true + + // Check if constraint exists in current + if currentTable != nil { + if currentConstraint, exists := currentTable.Constraints[constraintName]; exists { + if constraintsEqual(constraint, currentConstraint) { + shouldCreate = false + } + } + } + + if shouldCreate { + onDelete := "NO ACTION" + if constraint.OnDelete != "" { + onDelete = strings.ToUpper(constraint.OnDelete) + } + + onUpdate := "NO ACTION" + if constraint.OnUpdate != "" { + onUpdate = strings.ToUpper(constraint.OnUpdate) + } + + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName), + ObjectType: "create foreign key", + Schema: model.Name, + Priority: 195, + Sequence: len(scripts), + Body: fmt.Sprintf( + "ALTER TABLE %s.%s\n"+ + " DROP CONSTRAINT IF EXISTS %s;\n\n"+ + "ALTER TABLE %s.%s\n"+ + " ADD CONSTRAINT %s\n"+ + " FOREIGN KEY (%s)\n"+ + " REFERENCES %s.%s (%s)\n"+ + " ON DELETE %s\n"+ + " ON UPDATE %s\n"+ + " DEFERRABLE;", + model.Name, modelTable.Name, constraintName, + model.Name, modelTable.Name, constraintName, + strings.Join(constraint.Columns, ", "), + constraint.ReferencedSchema, constraint.ReferencedTable, + strings.Join(constraint.ReferencedColumns, ", "), + onDelete, onUpdate, + ), + } + scripts = append(scripts, script) + } + } + } + + return scripts +} + +// generateCommentScripts generates COMMENT ON scripts +func (w *MigrationWriter) generateCommentScripts(model *models.Schema, current *models.Schema) []MigrationScript { + scripts := make([]MigrationScript, 0) + + // Suppress unused parameter warning (current not used yet, could be used for diffing) + _ = current + + // Process each model table + for _, modelTable := range model.Tables { + // Table comment + if modelTable.Description != "" { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s", model.Name, modelTable.Name), + ObjectType: "comment on table", + Schema: model.Name, + Priority: 200, + Sequence: len(scripts), + Body: fmt.Sprintf( + "COMMENT ON TABLE %s.%s IS '%s';", + model.Name, modelTable.Name, escapeQuote(modelTable.Description), + ), + } + scripts = append(scripts, script) + } + + // Column comments + for _, col := range modelTable.Columns { + if col.Description != "" { + script := MigrationScript{ + ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, col.Name), + ObjectType: "comment on column", + Schema: model.Name, + Priority: 200, + Sequence: len(scripts), + Body: fmt.Sprintf( + "COMMENT ON COLUMN %s.%s.%s IS '%s';", + model.Name, modelTable.Name, col.Name, escapeQuote(col.Description), + ), + } + scripts = append(scripts, script) + } + } + } + + return scripts +} + +// Comparison helper functions + +func constraintsEqual(a, b *models.Constraint) bool { + if a.Type != b.Type { + return false + } + if len(a.Columns) != len(b.Columns) { + return false + } + for i := range a.Columns { + if !strings.EqualFold(a.Columns[i], b.Columns[i]) { + return false + } + } + if a.Type == models.ForeignKeyConstraint { + if a.ReferencedTable != b.ReferencedTable || a.ReferencedSchema != b.ReferencedSchema { + return false + } + if len(a.ReferencedColumns) != len(b.ReferencedColumns) { + return false + } + for i := range a.ReferencedColumns { + if !strings.EqualFold(a.ReferencedColumns[i], b.ReferencedColumns[i]) { + return false + } + } + } + return true +} + +func indexesEqual(a, b *models.Index) bool { + if a.Unique != b.Unique { + return false + } + if len(a.Columns) != len(b.Columns) { + return false + } + for i := range a.Columns { + if !strings.EqualFold(a.Columns[i], b.Columns[i]) { + return false + } + } + return true +} + +func columnsEqual(a, b *models.Column) bool { + if a.Type != b.Type { + return false + } + if a.NotNull != b.NotNull { + return false + } + if fmt.Sprintf("%v", a.Default) != fmt.Sprintf("%v", b.Default) { + return false + } + return true +} diff --git a/pkg/writers/pgsql/migration_writer_test.go b/pkg/writers/pgsql/migration_writer_test.go new file mode 100644 index 0000000..83eb156 --- /dev/null +++ b/pkg/writers/pgsql/migration_writer_test.go @@ -0,0 +1,287 @@ +package pgsql + +import ( + "bytes" + "strings" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +func TestWriteMigration_NewTable(t *testing.T) { + // Current database (empty) + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + current.Schemas = append(current.Schemas, currentSchema) + + // Model database (with new table) + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("users", "public") + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + idCol.NotNull = true + table.Columns["id"] = idCol + + nameCol := models.InitColumn("name", "users", "public") + nameCol.Type = "text" + table.Columns["name"] = nameCol + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + // Generate migration + var buf bytes.Buffer + writer := NewMigrationWriter(&writers.WriterOptions{}) + writer.writer = &buf + + err := writer.WriteMigration(model, current) + if err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + t.Logf("Generated migration:\n%s", output) + + // Verify CREATE TABLE is present + if !strings.Contains(output, "CREATE TABLE") { + t.Error("Migration missing CREATE TABLE statement") + } + if !strings.Contains(output, "users") { + t.Error("Migration missing table name 'users'") + } +} + +func TestWriteMigration_AddColumn(t *testing.T) { + // Current database (with table but missing column) + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("users", "public") + + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + currentTable.Columns["id"] = idCol + + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + // Model database (with additional column) + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("users", "public") + + idCol2 := models.InitColumn("id", "users", "public") + idCol2.Type = "integer" + modelTable.Columns["id"] = idCol2 + + emailCol := models.InitColumn("email", "users", "public") + emailCol.Type = "text" + modelTable.Columns["email"] = emailCol + + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + // Generate migration + var buf bytes.Buffer + writer := NewMigrationWriter(&writers.WriterOptions{}) + writer.writer = &buf + + err := writer.WriteMigration(model, current) + if err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + t.Logf("Generated migration:\n%s", output) + + // Verify ADD COLUMN is present + if !strings.Contains(output, "ADD COLUMN") { + t.Error("Migration missing ADD COLUMN statement") + } + if !strings.Contains(output, "email") { + t.Error("Migration missing column name 'email'") + } +} + +func TestWriteMigration_ChangeColumnType(t *testing.T) { + // Current database (with integer column) + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("users", "public") + + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + currentTable.Columns["id"] = idCol + + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + // Model database (changed to bigint) + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("users", "public") + + idCol2 := models.InitColumn("id", "users", "public") + idCol2.Type = "bigint" + modelTable.Columns["id"] = idCol2 + + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + // Generate migration + var buf bytes.Buffer + writer := NewMigrationWriter(&writers.WriterOptions{}) + writer.writer = &buf + + err := writer.WriteMigration(model, current) + if err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + t.Logf("Generated migration:\n%s", output) + + // Verify ALTER COLUMN TYPE is present + if !strings.Contains(output, "ALTER COLUMN") { + t.Error("Migration missing ALTER COLUMN statement") + } + if !strings.Contains(output, "TYPE bigint") { + t.Error("Migration missing TYPE bigint") + } +} + +func TestWriteMigration_AddForeignKey(t *testing.T) { + // Current database (two tables, no relationship) + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + + usersTable := models.InitTable("users", "public") + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + usersTable.Columns["id"] = idCol + + postsTable := models.InitTable("posts", "public") + postIdCol := models.InitColumn("id", "posts", "public") + postIdCol.Type = "integer" + postsTable.Columns["id"] = postIdCol + + userIdCol := models.InitColumn("user_id", "posts", "public") + userIdCol.Type = "integer" + postsTable.Columns["user_id"] = userIdCol + + currentSchema.Tables = append(currentSchema.Tables, usersTable, postsTable) + current.Schemas = append(current.Schemas, currentSchema) + + // Model database (with foreign key) + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + modelUsersTable := models.InitTable("users", "public") + modelIdCol := models.InitColumn("id", "users", "public") + modelIdCol.Type = "integer" + modelUsersTable.Columns["id"] = modelIdCol + + modelPostsTable := models.InitTable("posts", "public") + modelPostIdCol := models.InitColumn("id", "posts", "public") + modelPostIdCol.Type = "integer" + modelPostsTable.Columns["id"] = modelPostIdCol + + modelUserIdCol := models.InitColumn("user_id", "posts", "public") + modelUserIdCol.Type = "integer" + modelPostsTable.Columns["user_id"] = modelUserIdCol + + // Add foreign key constraint + fkConstraint := &models.Constraint{ + Name: "fk_posts_users", + Type: models.ForeignKeyConstraint, + Columns: []string{"user_id"}, + ReferencedTable: "users", + ReferencedSchema: "public", + ReferencedColumns: []string{"id"}, + OnDelete: "CASCADE", + OnUpdate: "CASCADE", + } + modelPostsTable.Constraints["fk_posts_users"] = fkConstraint + + modelSchema.Tables = append(modelSchema.Tables, modelUsersTable, modelPostsTable) + model.Schemas = append(model.Schemas, modelSchema) + + // Generate migration + var buf bytes.Buffer + writer := NewMigrationWriter(&writers.WriterOptions{}) + writer.writer = &buf + + err := writer.WriteMigration(model, current) + if err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + t.Logf("Generated migration:\n%s", output) + + // Verify FOREIGN KEY is present + if !strings.Contains(output, "FOREIGN KEY") { + t.Error("Migration missing FOREIGN KEY statement") + } + if !strings.Contains(output, "ON DELETE CASCADE") { + t.Error("Migration missing ON DELETE CASCADE") + } +} + +func TestWriteMigration_AddIndex(t *testing.T) { + // Current database (table without index) + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("users", "public") + + emailCol := models.InitColumn("email", "users", "public") + emailCol.Type = "text" + currentTable.Columns["email"] = emailCol + + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + // Model database (with unique index) + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("users", "public") + + modelEmailCol := models.InitColumn("email", "users", "public") + modelEmailCol.Type = "text" + modelTable.Columns["email"] = modelEmailCol + + // Add unique index + index := &models.Index{ + Name: "uk_users_email", + Unique: true, + Columns: []string{"email"}, + Type: "btree", + } + modelTable.Indexes["uk_users_email"] = index + + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + // Generate migration + var buf bytes.Buffer + writer := NewMigrationWriter(&writers.WriterOptions{}) + writer.writer = &buf + + err := writer.WriteMigration(model, current) + if err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + t.Logf("Generated migration:\n%s", output) + + // Verify CREATE UNIQUE INDEX is present + if !strings.Contains(output, "CREATE UNIQUE INDEX") { + t.Error("Migration missing CREATE UNIQUE INDEX statement") + } + if !strings.Contains(output, "uk_users_email") { + t.Error("Migration missing index name") + } +} diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go new file mode 100644 index 0000000..cf54f4b --- /dev/null +++ b/pkg/writers/pgsql/writer.go @@ -0,0 +1,496 @@ +package pgsql + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +// Writer implements the Writer interface for PostgreSQL SQL output +type Writer struct { + options *writers.WriterOptions + writer io.Writer +} + +// NewWriter creates a new PostgreSQL SQL writer +func NewWriter(options *writers.WriterOptions) *Writer { + return &Writer{ + options: options, + } +} + +// WriteDatabase writes the entire database schema as SQL +func (w *Writer) WriteDatabase(db *models.Database) error { + var writer io.Writer + var file *os.File + var err error + + // Use existing writer if already set (for testing) + if w.writer != nil { + writer = w.writer + } else if w.options.OutputPath != "" { + // Determine output destination + file, err = os.Create(w.options.OutputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer file.Close() + writer = file + } else { + writer = os.Stdout + } + + w.writer = writer + + // Write header comment + fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n") + fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name) + fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n") + + // Process each schema in the database + for _, schema := range db.Schemas { + if err := w.WriteSchema(schema); err != nil { + return fmt.Errorf("failed to write schema %s: %w", schema.Name, err) + } + } + + return nil +} + +// WriteSchema writes a single schema and all its tables +func (w *Writer) WriteSchema(schema *models.Schema) error { + if w.writer == nil { + w.writer = os.Stdout + } + + // Phase 1: Create schema (priority 1) + if err := w.writeCreateSchema(schema); err != nil { + return err + } + + // Phase 2: Create sequences (priority 80) + if err := w.writeSequences(schema); err != nil { + return err + } + + // Phase 3: Create tables with columns (priority 100) + if err := w.writeCreateTables(schema); err != nil { + return err + } + + // Phase 4: Create primary keys (priority 160) + if err := w.writePrimaryKeys(schema); err != nil { + return err + } + + // Phase 5: Create indexes (priority 180) + if err := w.writeIndexes(schema); err != nil { + return err + } + + // Phase 6: Create foreign key constraints (priority 195) + if err := w.writeForeignKeys(schema); err != nil { + return err + } + + // Phase 7: Set sequence values (priority 200) + if err := w.writeSetSequenceValues(schema); err != nil { + return err + } + + // Phase 8: Add comments (priority 200+) + if err := w.writeComments(schema); err != nil { + return err + } + + return nil +} + +// WriteTable writes a single table with all its elements +func (w *Writer) WriteTable(table *models.Table) error { + if w.writer == nil { + w.writer = os.Stdout + } + + // Create a temporary schema with just this table + schema := models.InitSchema(table.Schema) + schema.Tables = append(schema.Tables, table) + + return w.WriteSchema(schema) +} + +// writeCreateSchema generates CREATE SCHEMA statement +func (w *Writer) writeCreateSchema(schema *models.Schema) error { + if schema.Name == "public" { + // public schema exists by default + return nil + } + + fmt.Fprintf(w.writer, "-- Schema: %s\n", schema.Name) + fmt.Fprintf(w.writer, "CREATE SCHEMA IF NOT EXISTS %s;\n\n", schema.SQLName()) + return nil +} + +// writeSequences generates CREATE SEQUENCE statements for identity columns +func (w *Writer) writeSequences(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + pk := table.GetPrimaryKey() + if pk == nil { + continue + } + + // Only create sequences for integer-type PKs with identity + if !isIntegerType(pk.Type) { + continue + } + + seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName()) + fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n", + schema.SQLName(), seqName) + fmt.Fprintf(w.writer, " INCREMENT 1\n") + fmt.Fprintf(w.writer, " MINVALUE 1\n") + fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n") + fmt.Fprintf(w.writer, " START 1\n") + fmt.Fprintf(w.writer, " CACHE 1;\n\n") + } + + return nil +} + +// writeCreateTables generates CREATE TABLE statements +func (w *Writer) writeCreateTables(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s.%s (\n", + schema.SQLName(), table.SQLName()) + + // Write columns + columns := getSortedColumns(table.Columns) + columnDefs := make([]string, 0, len(columns)) + + for _, col := range columns { + colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type) + + // Add default value if present + if col.Default != "" { + colDef += fmt.Sprintf(" DEFAULT %s", col.Default) + } + + columnDefs = append(columnDefs, colDef) + } + + fmt.Fprintf(w.writer, "%s\n", strings.Join(columnDefs, ",\n")) + fmt.Fprintf(w.writer, ");\n\n") + } + + return nil +} + +// writePrimaryKeys generates ALTER TABLE statements for primary keys +func (w *Writer) writePrimaryKeys(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + // Find primary key constraint + var pkConstraint *models.Constraint + for name, constraint := range table.Constraints { + if constraint.Type == models.PrimaryKeyConstraint { + pkConstraint = constraint + _ = name // Use the name variable + break + } + } + + if pkConstraint == nil { + // No explicit PK constraint, skip + continue + } + + pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName()) + + // Build column list + columnNames := make([]string, 0, len(pkConstraint.Columns)) + for _, colName := range pkConstraint.Columns { + if col, ok := table.Columns[colName]; ok { + columnNames = append(columnNames, col.SQLName()) + } + } + + if len(columnNames) == 0 { + continue + } + + fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") + fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") + fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") + fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) + fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName) + fmt.Fprintf(w.writer, " ) THEN\n") + fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n", + pkName, strings.Join(columnNames, ", ")) + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "END;\n$$;\n\n") + } + + return nil +} + +// writeIndexes generates CREATE INDEX statements +func (w *Writer) writeIndexes(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Indexes for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + // Sort indexes by name for consistent output + indexNames := make([]string, 0, len(table.Indexes)) + for name := range table.Indexes { + indexNames = append(indexNames, name) + } + sort.Strings(indexNames) + + for _, name := range indexNames { + index := table.Indexes[name] + + // Skip if it's a primary key index (based on name convention or columns) + // Primary keys are handled separately + if strings.HasPrefix(strings.ToLower(index.Name), "pk_") { + continue + } + + indexName := index.Name + if indexName == "" { + indexType := "idx" + if index.Unique { + indexType = "uk" + } + indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName()) + } + + // Build column list + columnNames := make([]string, 0, len(index.Columns)) + for _, colName := range index.Columns { + if col, ok := table.Columns[colName]; ok { + columnNames = append(columnNames, col.SQLName()) + } + } + + if len(columnNames) == 0 { + continue + } + + unique := "" + if index.Unique { + unique = "UNIQUE " + } + + fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n", + unique, indexName) + fmt.Fprintf(w.writer, " ON %s.%s USING btree (%s);\n\n", + schema.SQLName(), table.SQLName(), strings.Join(columnNames, ", ")) + } + } + + return nil +} + +// writeForeignKeys generates ALTER TABLE statements for foreign keys +func (w *Writer) writeForeignKeys(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + // Sort relationships by name for consistent output + relNames := make([]string, 0, len(table.Relationships)) + for name := range table.Relationships { + relNames = append(relNames, name) + } + sort.Strings(relNames) + + for _, name := range relNames { + rel := table.Relationships[name] + + // For relationships, we need to look up the foreign key constraint + // that defines the actual column mappings + fkName := rel.ForeignKey + if fkName == "" { + fkName = name + } + if fkName == "" { + fkName = fmt.Sprintf("fk_%s_%s", table.SQLName(), rel.ToTable) + } + + // Find the foreign key constraint that matches this relationship + var fkConstraint *models.Constraint + for _, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint && + (constraint.Name == fkName || constraint.ReferencedTable == rel.ToTable) { + fkConstraint = constraint + break + } + } + + // If no constraint found, skip this relationship + if fkConstraint == nil { + continue + } + + // Build column lists from the constraint + sourceColumns := make([]string, 0, len(fkConstraint.Columns)) + for _, colName := range fkConstraint.Columns { + if col, ok := table.Columns[colName]; ok { + sourceColumns = append(sourceColumns, col.SQLName()) + } + } + + targetColumns := make([]string, 0, len(fkConstraint.ReferencedColumns)) + for _, colName := range fkConstraint.ReferencedColumns { + targetColumns = append(targetColumns, strings.ToLower(colName)) + } + + if len(sourceColumns) == 0 || len(targetColumns) == 0 { + continue + } + + onDelete := "NO ACTION" + if fkConstraint.OnDelete != "" { + onDelete = strings.ToUpper(fkConstraint.OnDelete) + } + + onUpdate := "NO ACTION" + if fkConstraint.OnUpdate != "" { + onUpdate = strings.ToUpper(fkConstraint.OnUpdate) + } + + fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName) + fmt.Fprintf(w.writer, "\n") + fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName) + fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", ")) + + // Use constraint's referenced schema/table or relationship's ToSchema/ToTable + refSchema := fkConstraint.ReferencedSchema + if refSchema == "" { + refSchema = rel.ToSchema + } + refTable := fkConstraint.ReferencedTable + if refTable == "" { + refTable = rel.ToTable + } + + fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n", + refSchema, refTable, strings.Join(targetColumns, ", ")) + fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete) + fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate) + fmt.Fprintf(w.writer, " DEFERRABLE;\n\n") + } + } + + return nil +} + +// writeSetSequenceValues generates statements to set sequence current values +func (w *Writer) writeSetSequenceValues(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Set sequence values for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + pk := table.GetPrimaryKey() + if pk == nil || !isIntegerType(pk.Type) { + continue + } + + seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName()) + + fmt.Fprintf(w.writer, "DO $$\n") + fmt.Fprintf(w.writer, "DECLARE\n") + fmt.Fprintf(w.writer, " m_cnt bigint;\n") + fmt.Fprintf(w.writer, "BEGIN\n") + fmt.Fprintf(w.writer, " IF EXISTS (\n") + fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n") + fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n") + fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName) + fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n") + fmt.Fprintf(w.writer, " ) THEN\n") + fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName()) + fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " INTO m_cnt;\n") + fmt.Fprintf(w.writer, " \n") + fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n", + schema.SQLName(), seqName) + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "END;\n") + fmt.Fprintf(w.writer, "$$;\n\n") + } + + return nil +} + +// writeComments generates COMMENT statements for tables and columns +func (w *Writer) writeComments(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Comments for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + // Table comment + if table.Description != "" { + fmt.Fprintf(w.writer, "COMMENT ON TABLE %s.%s IS '%s';\n", + schema.SQLName(), table.SQLName(), + escapeQuote(table.Description)) + } + + // Column comments + for _, col := range getSortedColumns(table.Columns) { + if col.Description != "" { + fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s.%s IS '%s';\n", + schema.SQLName(), table.SQLName(), col.SQLName(), + escapeQuote(col.Description)) + } + } + + fmt.Fprintf(w.writer, "\n") + } + + return nil +} + +// Helper functions + +// getSortedColumns returns columns sorted by name +func getSortedColumns(columns map[string]*models.Column) []*models.Column { + names := make([]string, 0, len(columns)) + for name := range columns { + names = append(names, name) + } + sort.Strings(names) + + sorted := make([]*models.Column, 0, len(columns)) + for _, name := range names { + sorted = append(sorted, columns[name]) + } + return sorted +} + +// isIntegerType checks if a column type is an integer type +func isIntegerType(colType string) bool { + intTypes := []string{"integer", "int", "bigint", "smallint", "serial", "bigserial"} + lowerType := strings.ToLower(colType) + for _, t := range intTypes { + if strings.HasPrefix(lowerType, t) { + return true + } + } + return false +} + +// escapeQuote escapes single quotes in strings for SQL +func escapeQuote(s string) string { + return strings.ReplaceAll(s, "'", "''") +} diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go new file mode 100644 index 0000000..097a93d --- /dev/null +++ b/pkg/writers/pgsql/writer_test.go @@ -0,0 +1,243 @@ +package pgsql + +import ( + "bytes" + "strings" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +func TestWriteDatabase(t *testing.T) { + // Create a test database + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create a test table + table := models.InitTable("users", "public") + table.Description = "User accounts table" + + // Add columns + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + idCol.Description = "Primary key" + idCol.Default = "nextval('public.identity_users_id'::regclass)" + table.Columns["id"] = idCol + + nameCol := models.InitColumn("name", "users", "public") + nameCol.Type = "text" + nameCol.Description = "User name" + table.Columns["name"] = nameCol + + emailCol := models.InitColumn("email", "users", "public") + emailCol.Type = "text" + emailCol.Description = "Email address" + table.Columns["email"] = emailCol + + // Add primary key constraint + pkConstraint := &models.Constraint{ + Name: "pk_users", + Type: models.PrimaryKeyConstraint, + Columns: []string{"id"}, + } + table.Constraints["pk_users"] = pkConstraint + + // Add unique index + uniqueEmailIndex := &models.Index{ + Name: "uk_users_email", + Unique: true, + Columns: []string{"email"}, + } + table.Indexes["uk_users_email"] = uniqueEmailIndex + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write the database + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + + // Print output for debugging + t.Logf("Generated SQL:\n%s", output) + + // Verify output contains expected elements + expectedStrings := []string{ + "CREATE TABLE", + "PRIMARY KEY", + "UNIQUE INDEX", + "COMMENT ON TABLE", + "COMMENT ON COLUMN", + } + + for _, expected := range expectedStrings { + if !strings.Contains(output, expected) { + t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output) + } + } +} + +func TestWriteForeignKeys(t *testing.T) { + // Create a test database with two related tables + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create parent table (users) + usersTable := models.InitTable("users", "public") + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + usersTable.Columns["id"] = idCol + + // Create child table (posts) + postsTable := models.InitTable("posts", "public") + postIdCol := models.InitColumn("id", "posts", "public") + postIdCol.Type = "integer" + postsTable.Columns["id"] = postIdCol + + userIdCol := models.InitColumn("user_id", "posts", "public") + userIdCol.Type = "integer" + postsTable.Columns["user_id"] = userIdCol + + // Add foreign key constraint + fkConstraint := &models.Constraint{ + Name: "fk_posts_users", + Type: models.ForeignKeyConstraint, + Columns: []string{"user_id"}, + ReferencedTable: "users", + ReferencedSchema: "public", + ReferencedColumns: []string{"id"}, + OnDelete: "CASCADE", + OnUpdate: "CASCADE", + } + postsTable.Constraints["fk_posts_users"] = fkConstraint + + // Add relationship + relationship := &models.Relationship{ + Name: "fk_posts_users", + FromTable: "posts", + FromSchema: "public", + ToTable: "users", + ToSchema: "public", + ForeignKey: "fk_posts_users", + } + postsTable.Relationships["fk_posts_users"] = relationship + + schema.Tables = append(schema.Tables, usersTable, postsTable) + db.Schemas = append(db.Schemas, schema) + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write the database + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + + // Print output for debugging + t.Logf("Generated SQL:\n%s", output) + + // Verify foreign key is present + if !strings.Contains(output, "FOREIGN KEY") { + t.Errorf("Output missing FOREIGN KEY statement\nFull output:\n%s", output) + } + if !strings.Contains(output, "ON DELETE CASCADE") { + t.Errorf("Output missing ON DELETE CASCADE\nFull output:\n%s", output) + } + if !strings.Contains(output, "ON UPDATE CASCADE") { + t.Errorf("Output missing ON UPDATE CASCADE\nFull output:\n%s", output) + } +} + +func TestWriteTable(t *testing.T) { + // Create a single table + table := models.InitTable("products", "public") + + idCol := models.InitColumn("id", "products", "public") + idCol.Type = "integer" + table.Columns["id"] = idCol + + nameCol := models.InitColumn("name", "products", "public") + nameCol.Type = "text" + table.Columns["name"] = nameCol + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write the table + err := writer.WriteTable(table) + if err != nil { + t.Fatalf("WriteTable failed: %v", err) + } + + output := buf.String() + + // Verify output contains table creation + if !strings.Contains(output, "CREATE TABLE") { + t.Error("Output missing CREATE TABLE statement") + } + if !strings.Contains(output, "products") { + t.Error("Output missing table name 'products'") + } +} + +func TestEscapeQuote(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"simple text", "simple text"}, + {"text with 'quote'", "text with ''quote''"}, + {"multiple 'quotes' here", "multiple ''quotes'' here"}, + {"", ""}, + } + + for _, tt := range tests { + result := escapeQuote(tt.input) + if result != tt.expected { + t.Errorf("escapeQuote(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestIsIntegerType(t *testing.T) { + tests := []struct { + input string + expected bool + }{ + {"integer", true}, + {"INTEGER", true}, + {"bigint", true}, + {"smallint", true}, + {"serial", true}, + {"bigserial", true}, + {"text", false}, + {"varchar", false}, + {"uuid", false}, + } + + for _, tt := range tests { + result := isIntegerType(tt.input) + if result != tt.expected { + t.Errorf("isIntegerType(%q) = %v, want %v", tt.input, result, tt.expected) + } + } +}