diff --git a/cmd/relspec/convert.go b/cmd/relspec/convert.go index c75d1ff..4e6d19a 100644 --- a/cmd/relspec/convert.go +++ b/cmd/relspec/convert.go @@ -43,16 +43,17 @@ import ( ) var ( - convertSourceType string - convertSourcePath string - convertSourceConn string - convertFromList []string - convertTargetType string - convertTargetPath string - convertPackageName string - convertSchemaFilter string - convertFlattenSchema bool - convertNullableTypes string + convertSourceType string + convertSourcePath string + convertSourceConn string + convertFromList []string + convertTargetType string + convertTargetPath string + convertPackageName string + convertSchemaFilter string + convertFlattenSchema bool + convertNullableTypes string + convertContinueOnError bool ) var convertCmd = &cobra.Command{ @@ -177,6 +178,7 @@ func init() { convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)") convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)") convertCmd.Flags().StringVar(&convertNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)") + convertCmd.Flags().BoolVar(&convertContinueOnError, "continue-on-error", false, "Prepend \\set ON_ERROR_STOP off to generated SQL so psql continues past errors (pgsql output only)") err := convertCmd.MarkFlagRequired("from") if err != nil { @@ -243,7 +245,7 @@ func runConvert(cmd *cobra.Command, args []string) error { fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter) } - if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes); err != nil { + if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes, convertContinueOnError); err != nil { return fmt.Errorf("failed to write target: %w", err) } @@ -383,10 +385,10 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa return db, nil } -func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error { +func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string, continueOnError bool) error { var writer writers.Writer - writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes) + writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes, continueOnError) switch strings.ToLower(dbType) { case "dbml": diff --git a/cmd/relspec/edit.go b/cmd/relspec/edit.go index fb27325..966290a 100644 --- a/cmd/relspec/edit.go +++ b/cmd/relspec/edit.go @@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa switch strings.ToLower(dbType) { case "dbml": - writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "dctx": - writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "drawdb": - writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "graphql": - writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "json": - writer = wjson.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wjson.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "yaml": - writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "gorm": - writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "bun": - writer = wbun.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wbun.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "drizzle": - writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "prisma": - writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "typeorm": - writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "sqlite", "sqlite3": - writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, "", false)) case "pgsql": - writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, "")) + writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, "", false)) default: return fmt.Errorf("%s: unsupported format: %s", label, dbType) } diff --git a/cmd/relspec/merge.go b/cmd/relspec/merge.go index 3677df9..27fae2e 100644 --- a/cmd/relspec/merge.go +++ b/cmd/relspec/merge.go @@ -375,61 +375,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab if filePath == "" { return fmt.Errorf("%s: file path is required for DBML format", label) } - writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "dctx": if filePath == "" { return fmt.Errorf("%s: file path is required for DCTX format", label) } - writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "drawdb": if filePath == "" { return fmt.Errorf("%s: file path is required for DrawDB format", label) } - writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "graphql": if filePath == "" { return fmt.Errorf("%s: file path is required for GraphQL format", label) } - writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "json": if filePath == "" { return fmt.Errorf("%s: file path is required for JSON format", label) } - writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "yaml": if filePath == "" { return fmt.Errorf("%s: file path is required for YAML format", label) } - writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "gorm": if filePath == "" { return fmt.Errorf("%s: file path is required for GORM format", label) } - writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "bun": if filePath == "" { return fmt.Errorf("%s: file path is required for Bun format", label) } - writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "drizzle": if filePath == "" { return fmt.Errorf("%s: file path is required for Drizzle format", label) } - writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "prisma": if filePath == "" { return fmt.Errorf("%s: file path is required for Prisma format", label) } - writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "typeorm": if filePath == "" { return fmt.Errorf("%s: file path is required for TypeORM format", label) } - writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "sqlite", "sqlite3": - writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) + writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false)) case "pgsql": - writerOpts := newWriterOptions(filePath, "", flattenSchema, "") + writerOpts := newWriterOptions(filePath, "", flattenSchema, "", false) if connString != "" { writerOpts.Metadata = map[string]interface{}{ "connection_string": connString, diff --git a/cmd/relspec/prisma_options.go b/cmd/relspec/prisma_options.go index 2675640..e391a91 100644 --- a/cmd/relspec/prisma_options.go +++ b/cmd/relspec/prisma_options.go @@ -13,12 +13,13 @@ func newReaderOptions(filePath, connString string) *readers.ReaderOptions { } } -func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions { +func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string, continueOnError bool) *writers.WriterOptions { return &writers.WriterOptions{ - OutputPath: outputPath, - PackageName: packageName, - FlattenSchema: flattenSchema, - NullableTypes: nullableTypes, - Prisma7: prisma7, + OutputPath: outputPath, + PackageName: packageName, + FlattenSchema: flattenSchema, + NullableTypes: nullableTypes, + Prisma7: prisma7, + ContinueOnError: continueOnError, } } diff --git a/cmd/relspec/split.go b/cmd/relspec/split.go index 876ac73..04da013 100644 --- a/cmd/relspec/split.go +++ b/cmd/relspec/split.go @@ -188,6 +188,7 @@ func runSplit(cmd *cobra.Command, args []string) error { "", // no schema filter for split false, // no flatten-schema for split splitNullableTypes, + false, // no continue-on-error for split ) if err != nil { return fmt.Errorf("failed to write output: %w", err) diff --git a/pkg/writers/pgsql/migration_writer.go b/pkg/writers/pgsql/migration_writer.go index bc72898..4f0dce3 100644 --- a/pkg/writers/pgsql/migration_writer.go +++ b/pkg/writers/pgsql/migration_writer.go @@ -144,7 +144,11 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models // Write header fmt.Fprintf(w.writer, "-- PostgreSQL Migration Script\n") fmt.Fprintf(w.writer, "-- Generated by RelSpec\n") - fmt.Fprintf(w.writer, "-- Source: %s -> %s\n\n", current.Name, model.Name) + fmt.Fprintf(w.writer, "-- Source: %s -> %s\n", current.Name, model.Name) + if w.options.ContinueOnError { + fmt.Fprintf(w.writer, "\\set ON_ERROR_STOP off\n") + } + fmt.Fprintf(w.writer, "\n") // Write scripts for _, script := range scripts { @@ -171,13 +175,15 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m }) } - // Phase 1: Drop constraints and indexes that changed (Priority 11-50) + // Phase 1: Drop constraints and indexes that changed (Priority 5-50) + var droppedFKs map[string]bool if current != nil { - dropScripts, err := w.generateDropScripts(model, current) + dropScripts, dropped, err := w.generateDropScripts(model, current) if err != nil { return nil, fmt.Errorf("failed to generate drop scripts: %w", err) } scripts = append(scripts, dropScripts...) + droppedFKs = dropped } // Phase 3: Create/Alter tables and columns (Priority 100-145) @@ -195,7 +201,7 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m scripts = append(scripts, indexScripts...) // Phase 5: Create foreign keys (Priority 195) - fkScripts, err := w.generateForeignKeyScripts(model, current) + fkScripts, err := w.generateForeignKeyScripts(model, current, droppedFKs) if err != nil { return nil, fmt.Errorf("failed to generate foreign key scripts: %w", err) } @@ -211,9 +217,12 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m return scripts, nil } -// generateDropScripts generates DROP scripts using templates -func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { +// generateDropScripts generates DROP scripts using templates. +// Returns the scripts and a set of FK constraint keys (schema.table.name) that were +// explicitly dropped because their referenced PK was being dropped, so they can be force-recreated. +func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, map[string]bool, error) { scripts := make([]MigrationScript, 0) + droppedFKs := make(map[string]bool) // Build map of model tables for quick lookup modelTables := make(map[string]*models.Table) @@ -240,6 +249,44 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod shouldDrop = true } + if shouldDrop && currentConstraint.Type == models.PrimaryKeyConstraint { + // Drop FK constraints that depend on this PK before dropping the PK itself. + for _, otherTable := range current.Tables { + for fkName, fkConstraint := range otherTable.Constraints { + if fkConstraint.Type != models.ForeignKeyConstraint { + continue + } + refTable := fkConstraint.ReferencedTable + refSchema := fkConstraint.ReferencedSchema + if refSchema == "" { + refSchema = current.Name + } + if strings.EqualFold(refTable, currentTable.Name) && strings.EqualFold(refSchema, current.Name) { + fkKey := fmt.Sprintf("%s.%s.%s", current.Name, otherTable.Name, fkName) + if !droppedFKs[fkKey] { + droppedFKs[fkKey] = true + sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{ + SchemaName: current.Name, + TableName: otherTable.Name, + ConstraintName: fkName, + }) + if err != nil { + return nil, nil, err + } + scripts = append(scripts, MigrationScript{ + ObjectName: fkKey, + ObjectType: "drop constraint", + Schema: current.Name, + Priority: 5, + Sequence: len(scripts), + Body: sql, + }) + } + } + } + } + } + if shouldDrop { sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{ SchemaName: current.Name, @@ -247,7 +294,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod ConstraintName: constraintName, }) if err != nil { - return nil, err + return nil, nil, err } script := MigrationScript{ @@ -279,7 +326,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod IndexName: indexName, }) if err != nil { - return nil, err + return nil, nil, err } script := MigrationScript{ @@ -295,7 +342,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod } } - return scripts, nil + return scripts, droppedFKs, nil } // generateTableScripts generates CREATE/ALTER TABLE scripts using templates @@ -631,8 +678,10 @@ func buildIndexColumnExpressions(table *models.Table, index *models.Index, index return columnExprs } -// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates -func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { +// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates. +// forceRecreate is a set of FK constraint keys (schema.table.name) that must be recreated +// even if unchanged, because their referenced PK was dropped and recreated. +func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema, forceRecreate map[string]bool) ([]MigrationScript, error) { scripts := make([]MigrationScript, 0) // Build map of current tables @@ -653,13 +702,16 @@ func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, curren continue } - shouldCreate := true + fkKey := fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName) + shouldCreate := forceRecreate[fkKey] - if currentTable != nil { - if currentConstraint, exists := currentTable.Constraints[constraintName]; exists { - if constraintsEqual(constraint, currentConstraint) { - shouldCreate = false - } + if !shouldCreate { + if currentTable == nil { + shouldCreate = true + } else if currentConstraint, exists := currentTable.Constraints[constraintName]; !exists { + shouldCreate = true + } else if !constraintsEqual(constraint, currentConstraint) { + shouldCreate = true } } diff --git a/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl b/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl index da70960..4f1564a 100644 --- a/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl +++ b/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl @@ -31,7 +31,7 @@ BEGIN IF current_pk_name IS NOT NULL AND NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}}) THEN - EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name); + EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name) || ' CASCADE'; END IF; -- Add the desired primary key only when no matching primary key already exists. diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 2b4e4d3..50ad275 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -99,7 +99,11 @@ func (w *Writer) WriteDatabase(db *models.Database) error { // Write header comment fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n") fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name) - fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n") + fmt.Fprintf(w.writer, "-- Generated by RelSpec\n") + if w.options.ContinueOnError { + fmt.Fprintf(w.writer, "\\set ON_ERROR_STOP off\n") + } + fmt.Fprintf(w.writer, "\n") // Process each schema in the database for _, schema := range db.Schemas { diff --git a/pkg/writers/writer.go b/pkg/writers/writer.go index 605ad79..0ada614 100644 --- a/pkg/writers/writer.go +++ b/pkg/writers/writer.go @@ -54,6 +54,10 @@ type WriterOptions struct { // Prisma7 enables Prisma 7-specific output for Prisma writers. Prisma7 bool + // ContinueOnError instructs SQL writers to prepend `\set ON_ERROR_STOP off` + // to their output so that psql continues past errors instead of stopping. + ContinueOnError bool + // Additional options can be added here as needed Metadata map[string]interface{} }