diff --git a/pkg/writers/pgsql/migration_writer.go b/pkg/writers/pgsql/migration_writer.go index 5272354..7732f37 100644 --- a/pkg/writers/pgsql/migration_writer.go +++ b/pkg/writers/pgsql/migration_writer.go @@ -31,6 +31,10 @@ type MigrationWriter struct { // NewMigrationWriter creates a new templated migration writer func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) { + if options == nil { + options = &writers.WriterOptions{} + } + executor, err := NewTemplateExecutor(options.FlattenSchema) if err != nil { return nil, fmt.Errorf("failed to create template executor: %w", err) @@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error // WriteMigration generates migration scripts using templates func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error { + if model == nil { + return fmt.Errorf("model database is required") + } + if w.options == nil { + w.options = &writers.WriterOptions{} + } + if current == nil { + current = models.InitDatabase(model.Name) + } + var writer io.Writer var file *os.File var err error @@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models // Process each schema in the model for _, modelSchema := range model.Schemas { + if modelSchema == nil { + continue + } + // Find corresponding schema in current database var currentSchema *models.Schema for _, cs := range current.Schemas { + if cs == nil { + continue + } if strings.EqualFold(cs.Name, modelSchema.Name) { currentSchema = cs break @@ -545,12 +566,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo indexType = modelIndex.Type } + columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType) + if len(columnExprs) == 0 { + continue + } + sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{ SchemaName: model.Name, TableName: modelTable.Name, IndexName: indexName, IndexType: indexType, - Columns: strings.Join(modelIndex.Columns, ", "), + Columns: strings.Join(columnExprs, ", "), Unique: modelIndex.Unique, }) if err != nil { @@ -573,6 +599,27 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo return scripts, nil } +func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string { + columnExprs := make([]string, 0, len(index.Columns)) + for _, colName := range index.Columns { + colExpr := colName + if table != nil { + if col, ok := table.Columns[colName]; ok && col != nil { + colExpr = col.SQLName() + if strings.EqualFold(indexType, "gin") && isTextType(col.Type) { + opClass := extractOperatorClass(index.Comment) + if opClass == "" { + opClass = "gin_trgm_ops" + } + colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass) + } + } + } + columnExprs = append(columnExprs, colExpr) + } + return columnExprs +} + // generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { scripts := make([]MigrationScript, 0) diff --git a/pkg/writers/pgsql/migration_writer_test.go b/pkg/writers/pgsql/migration_writer_test.go index 2d0f493..0aaa154 100644 --- a/pkg/writers/pgsql/migration_writer_test.go +++ b/pkg/writers/pgsql/migration_writer_test.go @@ -97,6 +97,89 @@ func TestWriteMigration_ArrayDefault(t *testing.T) { } } +func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("articles", "public") + titleCol := models.InitColumn("title", "articles", "public") + titleCol.Type = "text" + table.Columns["title"] = titleCol + + index := &models.Index{ + Name: "idx_articles_title_gin", + Type: "gin", + Columns: []string{"title"}, + } + table.Indexes[index.Name] = index + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "USING gin (title gin_trgm_ops)") { + t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output) + } +} + +func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("plans", "public") + tagsCol := models.InitColumn("tags", "plans", "public") + tagsCol.Type = "text[]" + table.Columns["tags"] = tagsCol + + index := &models.Index{ + Name: "idx_plans_tags", + Type: "gin", + Columns: []string{"tags"}, + } + table.Indexes[index.Name] = index + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "USING gin (tags)") { + t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output) + } + if strings.Contains(output, "gin_trgm_ops") { + t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output) + } +} + func TestWriteMigration_WithAudit(t *testing.T) { // Current database (empty) current := models.InitDatabase("testdb") @@ -322,3 +405,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) { t.Error("Migration missing FOREIGN KEY") } } + +func TestNewMigrationWriter_NilOptions(t *testing.T) { + writer, err := NewMigrationWriter(nil) + if err != nil { + t.Fatalf("NewMigrationWriter(nil) returned error: %v", err) + } + if writer == nil { + t.Fatal("expected writer instance") + } + if writer.options == nil { + t.Fatal("expected default writer options to be initialized") + } +} + +func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) { + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("users", "public") + idCol := models.InitColumn("id", "users", "public") + idCol.Type = "integer" + idCol.NotNull = true + table.Columns["id"] = idCol + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(nil) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, nil); err != nil { + t.Fatalf("WriteMigration with nil current failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "CREATE TABLE") { + t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output) + } +} diff --git a/pkg/writers/pgsql/templates.go b/pkg/writers/pgsql/templates.go index e910b82..0a59423 100644 --- a/pkg/writers/pgsql/templates.go +++ b/pkg/writers/pgsql/templates.go @@ -267,6 +267,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct { ConstraintName string AutoGenNames string // Comma-separated list of names like "'name1', 'name2'" Columns string + ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'" } // Execute methods for each template diff --git a/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl b/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl index 90cb5d1..cb12d1e 100644 --- a/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl +++ b/pkg/writers/pgsql/templates/create_primary_key_with_autogen_check.tmpl @@ -1,26 +1,42 @@ DO $$ DECLARE - auto_pk_name text; + current_pk_name text; + current_pk_matches boolean := false; BEGIN - -- Drop auto-generated primary key if it exists - SELECT constraint_name INTO auto_pk_name - FROM information_schema.table_constraints - WHERE table_schema = '{{.SchemaName}}' - AND table_name = '{{.TableName}}' - AND constraint_type = 'PRIMARY KEY' - AND constraint_name IN ({{.AutoGenNames}}); + SELECT tc.constraint_name, + COALESCE( + ARRAY( + SELECT a.attname + FROM pg_constraint c + JOIN pg_class t ON t.oid = c.conrelid + JOIN pg_namespace n ON n.oid = t.relnamespace + JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord) + ON TRUE + JOIN pg_attribute a + ON a.attrelid = t.oid + AND a.attnum = cols.attnum + WHERE c.contype = 'p' + AND n.nspname = '{{.SchemaName}}' + AND t.relname = '{{.TableName}}' + ORDER BY cols.ord + ), + ARRAY[]::text[] + ) = ARRAY[{{.ColumnNames}}] + INTO current_pk_name, current_pk_matches + FROM information_schema.table_constraints tc + WHERE tc.table_schema = '{{.SchemaName}}' + AND tc.table_name = '{{.TableName}}' + AND tc.constraint_type = 'PRIMARY KEY'; - IF auto_pk_name IS NOT NULL THEN - EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name); + IF current_pk_name IS NOT NULL + AND NOT current_pk_matches + AND current_pk_name IN ({{.AutoGenNames}}) THEN + EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name); END IF; - -- Add named primary key if it doesn't exist - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = '{{.SchemaName}}' - AND table_name = '{{.TableName}}' - AND constraint_name = '{{.ConstraintName}}' - ) THEN + -- Add the desired primary key only when no matching primary key already exists. + IF current_pk_name IS NULL + OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}}); END IF; END; diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 613fa36..c7aa92c 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -228,6 +228,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro ConstraintName: pkName, AutoGenNames: formatStringList(autoGenPKNames), Columns: strings.Join(pkColumns, ", "), + ColumnNames: formatStringList(pkColumns), } stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) @@ -806,6 +807,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error { ConstraintName: pkName, AutoGenNames: formatStringList(autoGenPKNames), Columns: strings.Join(columnNames, ", "), + ColumnNames: formatStringList(columnNames), } sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index 356a6d2..60b17a3 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -673,9 +673,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) { t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output) } - // Verify it checks for our specific named constraint before adding it - if !strings.Contains(output, "constraint_name = 'pk_public_products'") { - t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output) + // Verify it compares the current primary key columns before dropping/recreating + if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") { + t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output) + } + + // Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped + if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") { + t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output) } } @@ -766,6 +771,43 @@ func TestColumnSizeSpecifiers(t *testing.T) { } } +func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) { + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + table := models.InitTable("learnings", "public") + + idCol := models.InitColumn("id", "learnings", "public") + idCol.Type = "bigint" + idCol.IsPrimaryKey = true + table.Columns["id"] = idCol + + parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public") + parentCol.Type = "bigint" + table.Columns["duplicate_of_learning_id"] = parentCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + var buf bytes.Buffer + writer := NewWriter(&writers.WriterOptions{}) + writer.writer = &buf + + if err := writer.WriteDatabase(db); err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "current_pk_matches") { + t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output) + } + if !strings.Contains(output, "ARRAY['id']") { + t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output) + } + if !strings.Contains(output, "NOT current_pk_matches") { + t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output) + } +} + func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) { writer := NewWriter(&writers.WriterOptions{})