From 2d97a47ee102c41888ea77890f84cb708b7feb02 Mon Sep 17 00:00:00 2001 From: "Hein (Warky)" Date: Tue, 5 May 2026 14:50:34 +0200 Subject: [PATCH] feat: Enhance PostgreSQL type handling and migration scripts - Introduced equivalent base types and variants for PostgreSQL types to normalize type comparisons. - Added functions for normalizing SQL types and retrieving equivalent type variants. - Updated migration writer to handle type alterations with checks for existing types. - Implemented logic to create necessary extensions (e.g., pg_trgm) based on schema requirements. - Enhanced tests to cover new functionality for type normalization and migration handling. - Improved handling of GIN indexes to use appropriate operator classes based on column types. --- cmd/relspec/merge.go | 5 + cmd/relspec/merge_from_list_test.go | 36 +++ cmd/relspec/testhelpers_test.go | 34 +++ pkg/merge/merge.go | 99 ++++++- pkg/merge/merge_test.go | 62 ++++ pkg/pgsql/types_registry.go | 98 +++++++ pkg/pgsql/types_registry_test.go | 48 +++ pkg/writers/pgsql/migration_writer.go | 39 ++- pkg/writers/pgsql/migration_writer_test.go | 245 +++++++++++++++- pkg/writers/pgsql/templates.go | 19 ++ .../pgsql/templates/alter_column_type.tmpl | 2 +- .../alter_column_type_with_check.tmpl | 22 ++ pkg/writers/pgsql/writer.go | 275 ++++++++++++++---- pkg/writers/pgsql/writer_test.go | 123 +++++++- 14 files changed, 1042 insertions(+), 65 deletions(-) create mode 100644 pkg/writers/pgsql/templates/alter_column_type_with_check.tmpl diff --git a/cmd/relspec/merge.go b/cmd/relspec/merge.go index 55281e3..3677df9 100644 --- a/cmd/relspec/merge.go +++ b/cmd/relspec/merge.go @@ -258,6 +258,11 @@ func runMerge(cmd *cobra.Command, args []string) error { fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n") fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result)) + if strings.EqualFold(mergeOutputType, "pgsql") && len(result.TypeConflicts) > 0 { + return fmt.Errorf("merge detected conflicting existing column types and cannot safely continue with pgsql output\n%s", + merge.GetColumnTypeConflictSummary(result, 10)) + } + // Step 4: Write output fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n") fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType) diff --git a/cmd/relspec/merge_from_list_test.go b/cmd/relspec/merge_from_list_test.go index c590ba8..719dca2 100644 --- a/cmd/relspec/merge_from_list_test.go +++ b/cmd/relspec/merge_from_list_test.go @@ -3,6 +3,7 @@ package main import ( "os" "path/filepath" + "strings" "testing" ) @@ -160,3 +161,38 @@ func TestRunMerge_FromListMissingSourceType(t *testing.T) { t.Error("expected error when neither --source-path nor --from-list is provided") } } + +func TestRunMerge_PgsqlOutputRejectsColumnTypeConflict(t *testing.T) { + saved := saveMergeState() + defer restoreMergeState(saved) + + dir := t.TempDir() + targetFile := filepath.Join(dir, "target.json") + sourceFile := filepath.Join(dir, "source.json") + writeTestJSONWithSingleColumnType(t, targetFile, "users", "integer") + writeTestJSONWithSingleColumnType(t, sourceFile, "users", "uuid") + + mergeTargetType = "json" + mergeTargetPath = targetFile + mergeTargetConn = "" + mergeSourceType = "json" + mergeSourcePath = sourceFile + mergeSourceConn = "" + mergeFromList = nil + mergeOutputType = "pgsql" + mergeOutputPath = "" + mergeOutputConn = "postgres://relspec:secret@localhost/testdb" + mergeSkipTables = "" + mergeReportPath = "" + + err := runMerge(nil, nil) + if err == nil { + t.Fatal("expected pgsql output merge to fail on column type conflict") + } + if !strings.Contains(err.Error(), "column type conflicts detected") { + t.Fatalf("expected conflict summary in error, got: %v", err) + } + if !strings.Contains(err.Error(), "public.users.id") { + t.Fatalf("expected conflicting column path in error, got: %v", err) + } +} diff --git a/cmd/relspec/testhelpers_test.go b/cmd/relspec/testhelpers_test.go index fa79bee..4662ed3 100644 --- a/cmd/relspec/testhelpers_test.go +++ b/cmd/relspec/testhelpers_test.go @@ -71,6 +71,40 @@ func writeTestJSON(t *testing.T, path string, tableNames []string) { } } +func writeTestJSONWithSingleColumnType(t *testing.T, path, tableName, columnType string) { + t.Helper() + + db := minimalDatabase{ + Name: "test_db", + Schemas: []minimalSchema{{ + Name: "public", + Tables: []minimalTable{{ + Name: tableName, + Schema: "public", + Columns: map[string]minimalColumn{ + "id": { + Name: "id", + Table: tableName, + Schema: "public", + Type: columnType, + NotNull: true, + IsPrimaryKey: true, + AutoIncrement: true, + }, + }, + }}, + }}, + } + + data, err := json.Marshal(db) + if err != nil { + t.Fatalf("failed to marshal test JSON: %v", err) + } + if err := os.WriteFile(path, data, 0644); err != nil { + t.Fatalf("failed to write test file %s: %v", path, err) + } +} + // convertState captures and restores all convert global vars. type convertState struct { sourceType string diff --git a/pkg/merge/merge.go b/pkg/merge/merge.go index 63e7b08..0ed3186 100644 --- a/pkg/merge/merge.go +++ b/pkg/merge/merge.go @@ -22,6 +22,16 @@ type MergeResult struct { EnumsAdded int ViewsAdded int SequencesAdded int + TypeConflicts []ColumnTypeConflict +} + +// ColumnTypeConflict describes a column that exists in both schemas but with incompatible types. +type ColumnTypeConflict struct { + Schema string + Table string + Column string + TargetType string + SourceType string } // MergeOptions contains options for merge operations @@ -146,11 +156,19 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table) // Merge columns for colName, srcCol := range srcTable.Columns { - if _, exists := existingColumns[colName]; !exists { + if tgtCol, exists := existingColumns[colName]; !exists { // Column doesn't exist, add it newCol := cloneColumn(srcCol) table.Columns[colName] = newCol r.ColumnsAdded++ + } else if columnTypeConflict(tgtCol, srcCol) { + r.TypeConflicts = append(r.TypeConflicts, ColumnTypeConflict{ + Schema: firstNonEmpty(table.Schema, srcTable.Schema, srcCol.Schema), + Table: firstNonEmpty(table.Name, srcTable.Name, srcCol.Table), + Column: firstNonEmpty(tgtCol.Name, srcCol.Name, colName), + TargetType: describeColumnType(tgtCol), + SourceType: describeColumnType(srcCol), + }) } } } @@ -426,6 +444,52 @@ func cloneColumn(col *models.Column) *models.Column { return newCol } +func columnTypeConflict(target, source *models.Column) bool { + if target == nil || source == nil { + return false + } + + return normalizeType(target.Type) != normalizeType(source.Type) || + target.Length != source.Length || + target.Precision != source.Precision || + target.Scale != source.Scale +} + +func normalizeType(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func describeColumnType(col *models.Column) string { + if col == nil { + return "" + } + + typeName := strings.TrimSpace(col.Type) + if typeName == "" { + return "" + } + + switch { + case col.Precision > 0 && col.Scale > 0: + return fmt.Sprintf("%s(%d,%d)", typeName, col.Precision, col.Scale) + case col.Precision > 0: + return fmt.Sprintf("%s(%d)", typeName, col.Precision) + case col.Length > 0: + return fmt.Sprintf("%s(%d)", typeName, col.Length) + default: + return typeName + } +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + func cloneConstraint(constraint *models.Constraint) *models.Constraint { if constraint == nil { return nil @@ -609,6 +673,7 @@ func GetMergeSummary(result *MergeResult) string { fmt.Sprintf("Enums added: %d", result.EnumsAdded), fmt.Sprintf("Relations added: %d", result.RelationsAdded), fmt.Sprintf("Domains added: %d", result.DomainsAdded), + fmt.Sprintf("Type conflicts: %d", len(result.TypeConflicts)), } totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded + @@ -625,3 +690,35 @@ func GetMergeSummary(result *MergeResult) string { return summary } + +// GetColumnTypeConflictSummary returns a short, human-readable conflict summary. +func GetColumnTypeConflictSummary(result *MergeResult, limit int) string { + if result == nil || len(result.TypeConflicts) == 0 { + return "" + } + if limit <= 0 { + limit = len(result.TypeConflicts) + } + + lines := make([]string, 0, min(limit, len(result.TypeConflicts))+1) + lines = append(lines, "column type conflicts detected:") + for i, conflict := range result.TypeConflicts { + if i >= limit { + break + } + lines = append(lines, fmt.Sprintf(" - %s.%s.%s: target=%s source=%s", + conflict.Schema, conflict.Table, conflict.Column, conflict.TargetType, conflict.SourceType)) + } + if len(result.TypeConflicts) > limit { + lines = append(lines, fmt.Sprintf(" ... and %d more", len(result.TypeConflicts)-limit)) + } + + return strings.Join(lines, "\n") +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/merge/merge_test.go b/pkg/merge/merge_test.go index f93abef..e5f18cb 100644 --- a/pkg/merge/merge_test.go +++ b/pkg/merge/merge_test.go @@ -1,6 +1,7 @@ package merge import ( + "strings" "testing" "git.warky.dev/wdevs/relspecgo/pkg/models" @@ -140,6 +141,61 @@ func TestMergeColumns_NewColumn(t *testing.T) { } } +func TestMergeColumns_TypeConflictIsDetected(t *testing.T) { + target := &models.Database{ + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "users", + Schema: "public", + Columns: map[string]*models.Column{ + "email": {Name: "email", Type: "varchar", Length: 255}, + }, + }, + }, + }, + }, + } + source := &models.Database{ + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "users", + Schema: "public", + Columns: map[string]*models.Column{ + "email": {Name: "email", Type: "text"}, + }, + }, + }, + }, + }, + } + + result := MergeDatabases(target, source, nil) + + if len(result.TypeConflicts) != 1 { + t.Fatalf("Expected 1 type conflict, got %d", len(result.TypeConflicts)) + } + conflict := result.TypeConflicts[0] + if conflict.Schema != "public" || conflict.Table != "users" || conflict.Column != "email" { + t.Fatalf("Unexpected conflict location: %+v", conflict) + } + if conflict.TargetType != "varchar(255)" { + t.Fatalf("Expected target type varchar(255), got %q", conflict.TargetType) + } + if conflict.SourceType != "text" { + t.Fatalf("Expected source type text, got %q", conflict.SourceType) + } + + if got := target.Schemas[0].Tables[0].Columns["email"].Type; got != "varchar" { + t.Fatalf("Expected target column type to remain unchanged, got %q", got) + } +} + func TestMergeConstraints_NewConstraint(t *testing.T) { target := &models.Database{ Schemas: []*models.Schema{ @@ -509,6 +565,9 @@ func TestGetMergeSummary(t *testing.T) { ConstraintsAdded: 3, IndexesAdded: 2, ViewsAdded: 1, + TypeConflicts: []ColumnTypeConflict{ + {Schema: "public", Table: "users", Column: "email", TargetType: "varchar(255)", SourceType: "text"}, + }, } summary := GetMergeSummary(result) @@ -518,6 +577,9 @@ func TestGetMergeSummary(t *testing.T) { if len(summary) < 50 { t.Errorf("Summary seems too short: %s", summary) } + if !strings.Contains(summary, "Type conflicts: 1") { + t.Errorf("Expected type conflict count in summary, got: %s", summary) + } } func TestGetMergeSummary_Nil(t *testing.T) { diff --git a/pkg/pgsql/types_registry.go b/pkg/pgsql/types_registry.go index d0d507d..b6a14e8 100644 --- a/pkg/pgsql/types_registry.go +++ b/pkg/pgsql/types_registry.go @@ -145,6 +145,24 @@ var postgresTypeAliases = map[string]string{ "bool": "boolean", } +var postgresEquivalentBaseTypes = map[string]string{ + "character varying": "varchar", + "character": "char", + "timestamp without time zone": "timestamp", + "timestamp with time zone": "timestamptz", + "time without time zone": "time", + "time with time zone": "timetz", +} + +var postgresEquivalentBaseTypeVariants = map[string][]string{ + "varchar": {"varchar", "character varying"}, + "char": {"char", "character"}, + "timestamp": {"timestamp", "timestamp without time zone"}, + "timestamptz": {"timestamptz", "timestamp with time zone"}, + "time": {"time", "time without time zone"}, + "timetz": {"timetz", "time with time zone"}, +} + // GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names. func GetPostgresBaseTypes() []string { result := make([]string, 0, len(postgresBaseTypes)) @@ -212,6 +230,86 @@ func CanonicalizeBaseType(baseType string) string { return base } +// EquivalentBaseType resolves broader SQL-equivalent spellings to a common comparable form. +func EquivalentBaseType(baseType string) string { + base := CanonicalizeBaseType(baseType) + if equivalent, ok := postgresEquivalentBaseTypes[base]; ok { + return equivalent + } + return base +} + +// NormalizeEquivalentSQLType returns a normalized SQL type string suitable for equality checks. +// Equivalent spellings such as "character varying(255)" and "varchar(255)" normalize identically. +func NormalizeEquivalentSQLType(sqlType string) string { + t := normalizeTypeToken(sqlType) + if t == "" { + return "" + } + + arrayDepth := 0 + for strings.HasSuffix(t, "[]") { + arrayDepth++ + t = strings.TrimSpace(strings.TrimSuffix(t, "[]")) + } + + modifier := "" + if idx := strings.Index(t, "("); idx >= 0 { + modifier = strings.TrimSpace(t[idx:]) + t = strings.TrimSpace(t[:idx]) + } + + base := EquivalentBaseType(t) + normalized := base + modifier + for i := 0; i < arrayDepth; i++ { + normalized += "[]" + } + return normalized +} + +// EquivalentSQLTypeVariants returns equivalent PostgreSQL spellings for a SQL type. +// Examples: +// - varchar(255) -> ["varchar(255)", "character varying(255)"] +// - timestamptz -> ["timestamptz", "timestamp with time zone"] +func EquivalentSQLTypeVariants(sqlType string) []string { + t := normalizeTypeToken(sqlType) + if t == "" { + return nil + } + + arrayDepth := 0 + for strings.HasSuffix(t, "[]") { + arrayDepth++ + t = strings.TrimSpace(strings.TrimSuffix(t, "[]")) + } + + modifier := "" + if idx := strings.Index(t, "("); idx >= 0 { + modifier = strings.TrimSpace(t[idx:]) + t = strings.TrimSpace(t[:idx]) + } + + base := EquivalentBaseType(t) + bases := postgresEquivalentBaseTypeVariants[base] + if len(bases) == 0 { + bases = []string{base} + } + + seen := make(map[string]bool, len(bases)) + result := make([]string, 0, len(bases)) + for _, variantBase := range bases { + variant := variantBase + modifier + for i := 0; i < arrayDepth; i++ { + variant += "[]" + } + if !seen[variant] { + seen[variant] = true + result = append(result, variant) + } + } + return result +} + // IsKnownPostgresType reports whether a type (including array forms) exists in the registry. func IsKnownPostgresType(sqlType string) bool { base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType)) diff --git a/pkg/pgsql/types_registry_test.go b/pkg/pgsql/types_registry_test.go index 3a6c104..6653bbf 100644 --- a/pkg/pgsql/types_registry_test.go +++ b/pkg/pgsql/types_registry_test.go @@ -97,3 +97,51 @@ func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) { }) } } + +func TestNormalizeEquivalentSQLType(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "character varying(255)", want: "varchar(255)"}, + {input: "varchar(255)", want: "varchar(255)"}, + {input: "timestamp with time zone", want: "timestamptz"}, + {input: "timestamptz", want: "timestamptz"}, + {input: "time without time zone", want: "time"}, + {input: "character varying(255)[]", want: "varchar(255)[]"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := NormalizeEquivalentSQLType(tt.input) + if got != tt.want { + t.Fatalf("NormalizeEquivalentSQLType(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEquivalentSQLTypeVariants(t *testing.T) { + tests := []struct { + input string + want []string + }{ + {input: "character varying(255)", want: []string{"varchar(255)", "character varying(255)"}}, + {input: "timestamptz", want: []string{"timestamptz", "timestamp with time zone"}}, + {input: "text[]", want: []string{"text[]"}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := EquivalentSQLTypeVariants(tt.input) + if len(got) != len(tt.want) { + t.Fatalf("EquivalentSQLTypeVariants(%q) len = %d, want %d (%v)", tt.input, len(got), len(tt.want), got) + } + for i := range tt.want { + if got[i] != tt.want[i] { + t.Fatalf("EquivalentSQLTypeVariants(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i]) + } + } + }) + } +} diff --git a/pkg/writers/pgsql/migration_writer.go b/pkg/writers/pgsql/migration_writer.go index 7ebfab6..bc72898 100644 --- a/pkg/writers/pgsql/migration_writer.go +++ b/pkg/writers/pgsql/migration_writer.go @@ -160,6 +160,17 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { scripts := make([]MigrationScript, 0) + if schemaRequiresPGTrgm(model) { + scripts = append(scripts, MigrationScript{ + ObjectName: "extension.pg_trgm", + ObjectType: "create extension", + Schema: model.Name, + Priority: 80, + Sequence: len(scripts), + Body: "CREATE EXTENSION IF NOT EXISTS pg_trgm;", + }) + } + // Phase 1: Drop constraints and indexes that changed (Priority 11-50) if current != nil { dropScripts, err := w.generateDropScripts(model, current) @@ -361,7 +372,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model SchemaName: schema.Name, TableName: modelTable.Name, ColumnName: modelCol.Name, - ColumnType: pgsql.ConvertSQLType(modelCol.Type), + ColumnType: effectiveColumnSQLType(modelCol), Default: defaultVal, NotNull: modelCol.NotNull, }) @@ -380,12 +391,13 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model scripts = append(scripts, script) } else if !columnsEqual(modelCol, currentCol) { // Column exists but properties changed - if modelCol.Type != currentCol.Type { + if !columnTypesEqual(modelCol, currentCol) { sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{ SchemaName: schema.Name, TableName: modelTable.Name, ColumnName: modelCol.Name, - NewType: pgsql.ConvertSQLType(modelCol.Type), + NewType: effectiveAlterColumnSQLType(modelCol), + UsingExpr: buildAlterColumnUsingExpression(modelCol.Name, effectiveAlterColumnSQLType(modelCol)), }) if err != nil { return nil, err @@ -606,12 +618,11 @@ func buildIndexColumnExpressions(table *models.Table, index *models.Index, index if table != nil { if col, ok := resolveIndexColumn(table, colName); ok && col != nil { colExpr = col.SQLName() - if strings.EqualFold(indexType, "gin") && isTextType(col.Type) { - opClass := extractOperatorClass(index.Comment) - if opClass == "" { - opClass = "gin_trgm_ops" + if strings.EqualFold(indexType, "gin") { + opClass := ginOperatorClassForColumn(col, index.Comment) + if opClass != "" { + colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass) } - colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass) } } } @@ -875,11 +886,21 @@ func columnsEqual(col1, col2 *models.Column) bool { if col1 == nil || col2 == nil { return false } - return strings.EqualFold(col1.Type, col2.Type) && + return columnTypesEqual(col1, col2) && col1.NotNull == col2.NotNull && fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default) } +func columnTypesEqual(col1, col2 *models.Column) bool { + if col1 == nil || col2 == nil { + return false + } + return strings.EqualFold( + pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col1)), + pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col2)), + ) +} + // constraintsEqual checks if two constraints are equal func constraintsEqual(c1, c2 *models.Constraint) bool { if c1 == nil || c2 == nil { diff --git a/pkg/writers/pgsql/migration_writer_test.go b/pkg/writers/pgsql/migration_writer_test.go index f63a6b2..f39ed18 100644 --- a/pkg/writers/pgsql/migration_writer_test.go +++ b/pkg/writers/pgsql/migration_writer_test.go @@ -97,6 +97,160 @@ func TestWriteMigration_ArrayDefault(t *testing.T) { } } +func TestWriteMigration_AltersColumnTypeWhenActualTypeDiffers(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("learnings", "public") + currentDetails := models.InitColumn("details", "learnings", "public") + currentDetails.Type = "jsonb" + currentTable.Columns["details"] = currentDetails + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("learnings", "public") + modelDetails := models.InitColumn("details", "learnings", "public") + modelDetails.Type = "text" + modelTable.Columns["details"] = modelDetails + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "ALTER TABLE public.learnings") || !strings.Contains(output, "ALTER COLUMN details TYPE text") { + t.Fatalf("expected migration to alter mismatched column type, got:\n%s", output) + } + if !strings.Contains(output, `ALTER COLUMN details TYPE text USING details::text;`) { + t.Fatalf("expected migration type alter to include USING cast, got:\n%s", output) + } +} + +func TestWriteMigration_UsesStorageTypeForSerialAlterStatements(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("learnings", "public") + currentID := models.InitColumn("id", "learnings", "public") + currentID.Type = "uuid" + currentTable.Columns["id"] = currentID + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("learnings", "public") + modelID := models.InitColumn("id", "learnings", "public") + modelID.Type = "bigserial" + modelTable.Columns["id"] = modelID + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") { + t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output) + } + if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") { + t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output) + } + if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) { + t.Fatalf("expected serial alter to include USING cast, got:\n%s", output) + } +} + +func TestWriteMigration_ArrayAlterIncludesUsingCast(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("learnings", "public") + currentTags := models.InitColumn("tags", "learnings", "public") + currentTags.Type = "text" + currentTable.Columns["tags"] = currentTags + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("learnings", "public") + modelTags := models.InitColumn("tags", "learnings", "public") + modelTags.Type = "text[]" + modelTable.Columns["tags"] = modelTags + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) { + t.Fatalf("expected array alter to include USING cast, got:\n%s", output) + } +} + +func TestWriteMigration_DoesNotAlterEquivalentNormalizedColumnType(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + currentTable := models.InitTable("users", "public") + currentEmail := models.InitColumn("email", "users", "public") + currentEmail.Type = "character varying" + currentEmail.Length = 255 + currentTable.Columns["email"] = currentEmail + currentSchema.Tables = append(currentSchema.Tables, currentTable) + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + modelTable := models.InitTable("users", "public") + modelEmail := models.InitColumn("email", "users", "public") + modelEmail.Type = "varchar(255)" + modelTable.Columns["email"] = modelEmail + modelSchema.Tables = append(modelSchema.Tables, modelTable) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if strings.Contains(output, "ALTER COLUMN email TYPE") { + t.Fatalf("did not expect alter type for equivalent normalized types, got:\n%s", output) + } +} + func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) { current := models.InitDatabase("testdb") currentSchema := models.InitSchema("public") @@ -132,6 +286,9 @@ func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) { } output := buf.String() + if !strings.Contains(output, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") { + t.Fatalf("expected trigram extension for text GIN migration index, got:\n%s", output) + } if !strings.Contains(output, "USING gin (title gin_trgm_ops)") { t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output) } @@ -212,14 +369,98 @@ func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *tes } output := buf.String() - if !strings.Contains(output, "USING gin (tags)") { - t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output) + if !strings.Contains(output, "USING gin (tags array_ops)") { + t.Fatalf("expected GIN array index with array_ops, got:\n%s", output) } if strings.Contains(output, "gin_trgm_ops") { t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output) } } +func TestWriteMigration_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("learnings", "public") + detailsCol := models.InitColumn("details", "learnings", "public") + detailsCol.Type = "jsonb" + table.Columns["details"] = detailsCol + + index := &models.Index{ + Name: "idx_learnings_details", + Type: "gin", + Columns: []string{"details"}, + } + table.Indexes[index.Name] = index + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "USING gin (details jsonb_ops)") { + t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output) + } + if strings.Contains(output, "gin_trgm_ops") { + t.Fatalf("did not expect gin_trgm_ops for jsonb migration index, got:\n%s", output) + } +} + +func TestWriteMigration_GinIndexOnJSONBIgnoresIncompatibleTrigramOperatorClass(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("learnings", "public") + detailsCol := models.InitColumn("details", "learnings", "public") + detailsCol.Type = "jsonb" + table.Columns["details"] = detailsCol + + index := &models.Index{ + Name: "idx_learnings_details", + Type: "gin", + Columns: []string{"details"}, + Comment: "gin_trgm_ops", + } + table.Indexes[index.Name] = index + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + if err := writer.WriteMigration(model, current); err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "USING gin (details jsonb_ops)") { + t.Fatalf("expected incompatible trigram hint on jsonb to fall back to jsonb_ops, got:\n%s", output) + } +} + func TestWriteMigration_WithAudit(t *testing.T) { // Current database (empty) current := models.InitDatabase("testdb") diff --git a/pkg/writers/pgsql/templates.go b/pkg/writers/pgsql/templates.go index 0a59423..b299194 100644 --- a/pkg/writers/pgsql/templates.go +++ b/pkg/writers/pgsql/templates.go @@ -95,6 +95,16 @@ type AlterColumnTypeData struct { TableName string ColumnName string NewType string + UsingExpr string +} + +type AlterColumnTypeWithCheckData struct { + SchemaName string + TableName string + ColumnName string + NewType string + EquivalentTypes string + UsingExpr string } // AlterColumnDefaultData contains data for alter column default template @@ -302,6 +312,15 @@ func (te *TemplateExecutor) ExecuteAlterColumnType(data AlterColumnTypeData) (st return buf.String(), nil } +func (te *TemplateExecutor) ExecuteAlterColumnTypeWithCheck(data AlterColumnTypeWithCheckData) (string, error) { + var buf bytes.Buffer + err := te.templates.ExecuteTemplate(&buf, "alter_column_type_with_check.tmpl", data) + if err != nil { + return "", fmt.Errorf("failed to execute alter_column_type_with_check template: %w", err) + } + return buf.String(), nil +} + // ExecuteAlterColumnDefault executes the alter column default template func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) { var buf bytes.Buffer diff --git a/pkg/writers/pgsql/templates/alter_column_type.tmpl b/pkg/writers/pgsql/templates/alter_column_type.tmpl index 2d5e572..aaccd28 100644 --- a/pkg/writers/pgsql/templates/alter_column_type.tmpl +++ b/pkg/writers/pgsql/templates/alter_column_type.tmpl @@ -1,2 +1,2 @@ ALTER TABLE {{qual_table .SchemaName .TableName}} - ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}; \ No newline at end of file + ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}}; diff --git a/pkg/writers/pgsql/templates/alter_column_type_with_check.tmpl b/pkg/writers/pgsql/templates/alter_column_type_with_check.tmpl new file mode 100644 index 0000000..6c2ed8e --- /dev/null +++ b/pkg/writers/pgsql/templates/alter_column_type_with_check.tmpl @@ -0,0 +1,22 @@ +DO $$ +DECLARE + current_type text; +BEGIN + SELECT pg_catalog.format_type(a.atttypid, a.atttypmod) + INTO current_type + FROM pg_attribute a + JOIN pg_class t ON t.oid = a.attrelid + JOIN pg_namespace n ON n.oid = t.relnamespace + WHERE n.nspname = '{{.SchemaName}}' + AND t.relname = '{{.TableName}}' + AND a.attname = '{{.ColumnName}}' + AND a.attnum > 0 + AND NOT a.attisdropped; + + IF current_type IS NOT NULL + AND current_type <> ALL(ARRAY[{{.EquivalentTypes}}]) THEN + ALTER TABLE {{qual_table .SchemaName .TableName}} + ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}}; + END IF; +END; +$$; diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index a4caac5..2b4e4d3 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -143,6 +143,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName())) } + if schemaRequiresPGTrgm(schema) { + statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) + } + // Phase 2: Create sequences for _, table := range schema.Tables { pk := table.GetPrimaryKey() @@ -181,6 +185,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro } statements = append(statements, addColStmts...) + alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema) + if err != nil { + return nil, fmt.Errorf("failed to generate alter column type statements: %w", err) + } + statements = append(statements, alterTypeStmts...) + // Phase 4: Primary keys for _, table := range schema.Tables { // First check for explicit PrimaryKeyConstraint @@ -262,13 +272,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro for _, colName := range index.Columns { colExpr := colName if col, ok := resolveIndexColumn(table, colName); ok { - // For GIN indexes on text columns, add operator class - if strings.EqualFold(indexType, "gin") && isTextType(col.Type) { - opClass := extractOperatorClass(index.Comment) - if opClass == "" { - opClass = "gin_trgm_ops" + if strings.EqualFold(indexType, "gin") { + if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" { + colExpr = fmt.Sprintf("%s %s", colName, opClass) } - colExpr = fmt.Sprintf("%s %s", colName, opClass) } } columnExprs = append(columnExprs, colExpr) @@ -437,6 +444,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e return statements, nil } +func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) { + statements := []string{} + + statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name)) + + for _, table := range schema.Tables { + columns := getSortedColumns(table.Columns) + for _, col := range columns { + targetType := effectiveAlterColumnSQLType(col) + stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{ + SchemaName: schema.Name, + TableName: table.Name, + ColumnName: col.Name, + NewType: targetType, + EquivalentTypes: equivalentTypeListSQL(targetType), + UsingExpr: buildAlterColumnUsingExpression(col.Name, targetType), + }) + if err != nil { + return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err) + } + statements = append(statements, stmt) + } + } + + return statements, nil +} + // GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) { statements := []string{} @@ -489,31 +523,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode func (w *Writer) generateColumnDefinition(col *models.Column) string { parts := []string{col.SQLName()} - // Type with length/precision - convert to valid PostgreSQL type - baseType := pgsql.ConvertSQLType(col.Type) - typeStr := baseType - hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType) - - // Only add size specifiers for types that support them - if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 { - if pgsql.SupportsLength(baseType) { - typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) - } else if isTextTypeWithoutLength(baseType) { - // Convert text with length to varchar - typeStr = fmt.Sprintf("varchar(%d)", col.Length) - } - // For types that don't support length (integer, bigint, etc.), ignore the length - } else if !hasExplicitTypeModifier && col.Precision > 0 { - if pgsql.SupportsPrecision(baseType) { - if col.Scale > 0 { - typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) - } else { - typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision) - } - } - // For types that don't support precision, ignore it - } - parts = append(parts, typeStr) + parts = append(parts, effectiveColumnSQLType(col)) // NOT NULL if col.NotNull { @@ -535,6 +545,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string { return strings.Join(parts, " ") } +func effectiveColumnSQLType(col *models.Column) string { + if col == nil { + return "" + } + + baseType := pgsql.ConvertSQLType(col.Type) + typeStr := baseType + hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType) + + if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 { + if pgsql.SupportsLength(baseType) { + typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) + } else if isTextTypeWithoutLength(baseType) { + typeStr = fmt.Sprintf("varchar(%d)", col.Length) + } + } else if !hasExplicitTypeModifier && col.Precision > 0 { + if pgsql.SupportsPrecision(baseType) { + if col.Scale > 0 { + typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) + } else { + typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision) + } + } + } + + return typeStr +} + +func effectiveAlterColumnSQLType(col *models.Column) string { + typeStr := effectiveColumnSQLType(col) + switch strings.ToLower(strings.TrimSpace(typeStr)) { + case "smallserial": + return "smallint" + case "serial": + return "integer" + case "bigserial": + return "bigint" + default: + return typeStr + } +} + +func buildAlterColumnUsingExpression(columnName, targetType string) string { + if strings.TrimSpace(columnName) == "" || strings.TrimSpace(targetType) == "" { + return "" + } + return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType) +} + +func equivalentTypeListSQL(sqlType string) string { + variants := pgsql.EquivalentSQLTypeVariants(sqlType) + quoted := make([]string, 0, len(variants)) + for _, variant := range variants { + quoted = append(quoted, fmt.Sprintf("'%s'", escapeQuote(variant))) + } + return strings.Join(quoted, ", ") +} + // WriteSchema writes a single schema and all its tables func (w *Writer) WriteSchema(schema *models.Schema) error { if w.writer == nil { @@ -546,6 +614,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error { return err } + if err := w.writeRequiredExtensions(schema); err != nil { + return err + } + // Phase 2: Create sequences (priority 80) if err := w.writeSequences(schema); err != nil { return err @@ -561,6 +633,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error { return err } + if err := w.writeAlterColumnTypes(schema); err != nil { + return err + } + // Phase 4: Create primary keys (priority 160) if err := w.writePrimaryKeys(schema); err != nil { return err @@ -661,6 +737,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error { return nil } +func (w *Writer) writeRequiredExtensions(schema *models.Schema) error { + if !schemaRequiresPGTrgm(schema) { + return nil + } + + fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") + fmt.Fprintln(w.writer) + return nil +} + // writeSequences generates CREATE SEQUENCE statements for identity columns func (w *Writer) writeSequences(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name) @@ -754,6 +840,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error { return nil } +func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name) + + statements, err := w.GenerateAlterColumnTypeStatements(schema) + if err != nil { + return err + } + for _, stmt := range statements[1:] { + fmt.Fprint(w.writer, stmt) + fmt.Fprint(w.writer, "\n") + } + + return nil +} + // writePrimaryKeys generates ALTER TABLE statements for primary keys func (w *Writer) writePrimaryKeys(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name) @@ -857,13 +958,11 @@ func (w *Writer) writeIndexes(schema *models.Schema) error { for _, colName := range index.Columns { if col, ok := resolveIndexColumn(table, colName); ok { colExpr := col.SQLName() - // For GIN indexes on text columns, add operator class - if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) { - opClass := extractOperatorClass(index.Comment) - if opClass == "" { - opClass = "gin_trgm_ops" + if strings.EqualFold(index.Type, "gin") { + opClass := ginOperatorClassForColumn(col, index.Comment) + if opClass != "" { + colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass) } - colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass) } columnExprs = append(columnExprs, colExpr) } @@ -1250,25 +1349,101 @@ func isIntegerType(colType string) bool { } // isTextType checks if a column type is a text type (for GIN index operator class) -func isTextType(colType string) bool { - textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"} - lowerType := strings.ToLower(colType) - if strings.HasSuffix(lowerType, "[]") { - return false - } - for _, t := range textTypes { - if strings.HasPrefix(lowerType, t) { - return true - } - } - return false -} +// func isTextType(colType string) bool { +// textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"} +// lowerType := strings.ToLower(colType) +// if strings.HasSuffix(lowerType, "[]") { +// return false +// } +// for _, t := range textTypes { +// if strings.HasPrefix(lowerType, t) { +// return true +// } +// } +// return false +// } // isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified) func isTextTypeWithoutLength(colType string) bool { return strings.EqualFold(colType, "text") } +func ginOperatorClassForColumn(col *models.Column, comment string) string { + if col == nil { + return "" + } + + sqlType := effectiveColumnSQLType(col) + baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) + isArray := pgsql.IsArrayType(sqlType) + requested := extractOperatorClass(comment) + + if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) { + return requested + } + + if isArray { + return "array_ops" + } + + switch { + case isTextGinBaseType(baseType): + return "gin_trgm_ops" + case baseType == "jsonb": + return "jsonb_ops" + default: + return requested + } +} + +func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool { + switch opClass { + case "gin_trgm_ops", "gin_bigm_ops": + return !isArray && isTextGinBaseType(baseType) + case "jsonb_ops", "jsonb_path_ops": + return !isArray && baseType == "jsonb" + case "array_ops": + return isArray + default: + return true + } +} + +func isTextGinBaseType(baseType string) bool { + switch baseType { + case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar": + return true + default: + return false + } +} + +func schemaRequiresPGTrgm(schema *models.Schema) bool { + if schema == nil { + return false + } + for _, table := range schema.Tables { + if table == nil { + continue + } + for _, index := range table.Indexes { + if index == nil || !strings.EqualFold(index.Type, "gin") { + continue + } + for _, colName := range index.Columns { + col, ok := resolveIndexColumn(table, colName) + if !ok || col == nil { + continue + } + if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" { + return true + } + } + } + } + return false +} + func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) { if table == nil { return nil, false diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index 47eaa52..eeed3f2 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -116,8 +116,8 @@ func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *test } output := buf.String() - if !strings.Contains(output, `USING gin (tags)`) { - t.Fatalf("expected GIN index on array column without explicit trigram opclass, got:\n%s", output) + if !strings.Contains(output, `USING gin (tags array_ops)`) { + t.Fatalf("expected GIN index on array column with array_ops, got:\n%s", output) } if strings.Contains(output, "gin_trgm_ops") { t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output) @@ -153,11 +153,51 @@ func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *tes } output := buf.String() + if !strings.Contains(output, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) { + t.Fatalf("expected trigram extension for text GIN index, got:\n%s", output) + } if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) { t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output) } } +func TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) { + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + table := models.InitTable("learnings", "public") + + detailsCol := models.InitColumn("details", "learnings", "public") + detailsCol.Type = "jsonb" + table.Columns["details"] = detailsCol + + index := &models.Index{ + Name: "idx_learnings_details", + Type: "gin", + Columns: []string{"details"}, + } + table.Indexes[index.Name] = index + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + var buf bytes.Buffer + writer := NewWriter(&writers.WriterOptions{}) + writer.writer = &buf + + if err := writer.WriteDatabase(db); err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, `USING gin (details jsonb_ops)`) { + t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output) + } + if strings.Contains(output, "gin_trgm_ops") { + t.Fatalf("did not expect gin_trgm_ops for jsonb column, got:\n%s", output) + } +} + func TestWriteForeignKeys(t *testing.T) { // Create a test database with two related tables db := models.InitDatabase("testdb") @@ -1018,3 +1058,82 @@ func TestWriteAddColumnStatements(t *testing.T) { t.Errorf("Output missing DO block\nFull output:\n%s", output) } } + +func TestWriteSchema_EmitsGuardedAlterColumnTypeStatements(t *testing.T) { + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + table := models.InitTable("agent_skills", "public") + + nameCol := models.InitColumn("name", "agent_skills", "public") + nameCol.Type = "character varying" + nameCol.Length = 255 + table.Columns["name"] = nameCol + + tagsCol := models.InitColumn("tags", "agent_skills", "public") + tagsCol.Type = "text[]" + table.Columns["tags"] = tagsCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + var buf bytes.Buffer + writer := NewWriter(&writers.WriterOptions{}) + writer.writer = &buf + + if err := writer.WriteDatabase(db); err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "-- Alter column types for schema: public") { + t.Fatalf("expected alter column type section, got:\n%s", output) + } + if !strings.Contains(output, "pg_catalog.format_type") { + t.Fatalf("expected guarded live-type check, got:\n%s", output) + } + if !strings.Contains(output, "ALTER COLUMN name TYPE character varying(255)") { + t.Fatalf("expected guarded alter for character varying(255), got:\n%s", output) + } + if !strings.Contains(output, "ARRAY['varchar(255)', 'character varying(255)']") { + t.Fatalf("expected equivalent type spellings for varchar guard, got:\n%s", output) + } + if !strings.Contains(output, "ALTER COLUMN tags TYPE text[]") { + t.Fatalf("expected guarded alter for array type, got:\n%s", output) + } + if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) { + t.Fatalf("expected guarded alter for array type to include USING cast, got:\n%s", output) + } +} + +func TestWriteSchema_UsesStorageTypeForSerialAlterStatements(t *testing.T) { + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + table := models.InitTable("learnings", "public") + idCol := models.InitColumn("id", "learnings", "public") + idCol.Type = "bigserial" + table.Columns["id"] = idCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + var buf bytes.Buffer + writer := NewWriter(&writers.WriterOptions{}) + writer.writer = &buf + + if err := writer.WriteDatabase(db); err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") { + t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output) + } + if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") { + t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output) + } + if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) { + t.Fatalf("expected serial alter to include USING cast, got:\n%s", output) + } +}