feat(migration): enhance primary key handling and add GIN index support in migration writer

This commit is contained in:
2026-05-05 11:12:23 +02:00
parent a447b68b22
commit 17bc8ed395
6 changed files with 255 additions and 21 deletions

View File

@@ -31,6 +31,10 @@ type MigrationWriter struct {
// NewMigrationWriter creates a new templated migration writer // NewMigrationWriter creates a new templated migration writer
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) { func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
if options == nil {
options = &writers.WriterOptions{}
}
executor, err := NewTemplateExecutor(options.FlattenSchema) executor, err := NewTemplateExecutor(options.FlattenSchema)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create template executor: %w", err) return nil, fmt.Errorf("failed to create template executor: %w", err)
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
// WriteMigration generates migration scripts using templates // WriteMigration generates migration scripts using templates
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error { func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
if model == nil {
return fmt.Errorf("model database is required")
}
if w.options == nil {
w.options = &writers.WriterOptions{}
}
if current == nil {
current = models.InitDatabase(model.Name)
}
var writer io.Writer var writer io.Writer
var file *os.File var file *os.File
var err error var err error
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
// Process each schema in the model // Process each schema in the model
for _, modelSchema := range model.Schemas { for _, modelSchema := range model.Schemas {
if modelSchema == nil {
continue
}
// Find corresponding schema in current database // Find corresponding schema in current database
var currentSchema *models.Schema var currentSchema *models.Schema
for _, cs := range current.Schemas { for _, cs := range current.Schemas {
if cs == nil {
continue
}
if strings.EqualFold(cs.Name, modelSchema.Name) { if strings.EqualFold(cs.Name, modelSchema.Name) {
currentSchema = cs currentSchema = cs
break break
@@ -545,12 +566,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
indexType = modelIndex.Type indexType = modelIndex.Type
} }
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
if len(columnExprs) == 0 {
continue
}
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{ sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
SchemaName: model.Name, SchemaName: model.Name,
TableName: modelTable.Name, TableName: modelTable.Name,
IndexName: indexName, IndexName: indexName,
IndexType: indexType, IndexType: indexType,
Columns: strings.Join(modelIndex.Columns, ", "), Columns: strings.Join(columnExprs, ", "),
Unique: modelIndex.Unique, Unique: modelIndex.Unique,
}) })
if err != nil { if err != nil {
@@ -573,6 +599,27 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
return scripts, nil return scripts, nil
} }
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
if table != nil {
if col, ok := table.Columns[colName]; ok && col != nil {
colExpr = col.SQLName()
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
}
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
}
}
columnExprs = append(columnExprs, colExpr)
}
return columnExprs
}
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates // generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0) scripts := make([]MigrationScript, 0)

View File

@@ -97,6 +97,89 @@ func TestWriteMigration_ArrayDefault(t *testing.T) {
} }
} }
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("articles", "public")
titleCol := models.InitColumn("title", "articles", "public")
titleCol.Type = "text"
table.Columns["title"] = titleCol
index := &models.Index{
Name: "idx_articles_title_gin",
Type: "gin",
Columns: []string{"title"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
index := &models.Index{
Name: "idx_plans_tags",
Type: "gin",
Columns: []string{"tags"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (tags)") {
t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
}
}
func TestWriteMigration_WithAudit(t *testing.T) { func TestWriteMigration_WithAudit(t *testing.T) {
// Current database (empty) // Current database (empty)
current := models.InitDatabase("testdb") current := models.InitDatabase("testdb")
@@ -322,3 +405,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
t.Error("Migration missing FOREIGN KEY") t.Error("Migration missing FOREIGN KEY")
} }
} }
func TestNewMigrationWriter_NilOptions(t *testing.T) {
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
}
if writer == nil {
t.Fatal("expected writer instance")
}
if writer.options == nil {
t.Fatal("expected default writer options to be initialized")
}
}
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
table.Columns["id"] = idCol
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, nil); err != nil {
t.Fatalf("WriteMigration with nil current failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "CREATE TABLE") {
t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
}
}

View File

@@ -267,6 +267,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
ConstraintName string ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'" AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string Columns string
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
} }
// Execute methods for each template // Execute methods for each template

View File

@@ -1,26 +1,42 @@
DO $$ DO $$
DECLARE DECLARE
auto_pk_name text; current_pk_name text;
current_pk_matches boolean := false;
BEGIN BEGIN
-- Drop auto-generated primary key if it exists SELECT tc.constraint_name,
SELECT constraint_name INTO auto_pk_name COALESCE(
FROM information_schema.table_constraints ARRAY(
WHERE table_schema = '{{.SchemaName}}' SELECT a.attname
AND table_name = '{{.TableName}}' FROM pg_constraint c
AND constraint_type = 'PRIMARY KEY' JOIN pg_class t ON t.oid = c.conrelid
AND constraint_name IN ({{.AutoGenNames}}); JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
ON TRUE
JOIN pg_attribute a
ON a.attrelid = t.oid
AND a.attnum = cols.attnum
WHERE c.contype = 'p'
AND n.nspname = '{{.SchemaName}}'
AND t.relname = '{{.TableName}}'
ORDER BY cols.ord
),
ARRAY[]::text[]
) = ARRAY[{{.ColumnNames}}]
INTO current_pk_name, current_pk_matches
FROM information_schema.table_constraints tc
WHERE tc.table_schema = '{{.SchemaName}}'
AND tc.table_name = '{{.TableName}}'
AND tc.constraint_type = 'PRIMARY KEY';
IF auto_pk_name IS NOT NULL THEN IF current_pk_name IS NOT NULL
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name); AND NOT current_pk_matches
AND current_pk_name IN ({{.AutoGenNames}}) THEN
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
END IF; END IF;
-- Add named primary key if it doesn't exist -- Add the desired primary key only when no matching primary key already exists.
IF NOT EXISTS ( IF current_pk_name IS NULL
SELECT 1 FROM information_schema.table_constraints OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}}); ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF; END IF;
END; END;

View File

@@ -228,6 +228,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
ConstraintName: pkName, ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames), AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "), Columns: strings.Join(pkColumns, ", "),
ColumnNames: formatStringList(pkColumns),
} }
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -806,6 +807,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
ConstraintName: pkName, ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames), AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "), Columns: strings.Join(columnNames, ", "),
ColumnNames: formatStringList(columnNames),
} }
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)

View File

@@ -673,9 +673,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output) t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
} }
// Verify it checks for our specific named constraint before adding it // Verify it compares the current primary key columns before dropping/recreating
if !strings.Contains(output, "constraint_name = 'pk_public_products'") { if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output) t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
}
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
} }
} }
@@ -766,6 +771,43 @@ func TestColumnSizeSpecifiers(t *testing.T) {
} }
} }
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
idCol := models.InitColumn("id", "learnings", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
parentCol.Type = "bigint"
table.Columns["duplicate_of_learning_id"] = parentCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "current_pk_matches") {
t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
}
if !strings.Contains(output, "ARRAY['id']") {
t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
}
if !strings.Contains(output, "NOT current_pk_matches") {
t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
}
}
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) { func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
writer := NewWriter(&writers.WriterOptions{}) writer := NewWriter(&writers.WriterOptions{})