feat: Enhance PostgreSQL type handling and migration scripts
- Introduced equivalent base types and variants for PostgreSQL types to normalize type comparisons. - Added functions for normalizing SQL types and retrieving equivalent type variants. - Updated migration writer to handle type alterations with checks for existing types. - Implemented logic to create necessary extensions (e.g., pg_trgm) based on schema requirements. - Enhanced tests to cover new functionality for type normalization and migration handling. - Improved handling of GIN indexes to use appropriate operator classes based on column types.
This commit is contained in:
@@ -160,6 +160,17 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
|
||||
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
||||
scripts := make([]MigrationScript, 0)
|
||||
|
||||
if schemaRequiresPGTrgm(model) {
|
||||
scripts = append(scripts, MigrationScript{
|
||||
ObjectName: "extension.pg_trgm",
|
||||
ObjectType: "create extension",
|
||||
Schema: model.Name,
|
||||
Priority: 80,
|
||||
Sequence: len(scripts),
|
||||
Body: "CREATE EXTENSION IF NOT EXISTS pg_trgm;",
|
||||
})
|
||||
}
|
||||
|
||||
// Phase 1: Drop constraints and indexes that changed (Priority 11-50)
|
||||
if current != nil {
|
||||
dropScripts, err := w.generateDropScripts(model, current)
|
||||
@@ -361,7 +372,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
SchemaName: schema.Name,
|
||||
TableName: modelTable.Name,
|
||||
ColumnName: modelCol.Name,
|
||||
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
||||
ColumnType: effectiveColumnSQLType(modelCol),
|
||||
Default: defaultVal,
|
||||
NotNull: modelCol.NotNull,
|
||||
})
|
||||
@@ -380,12 +391,13 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
scripts = append(scripts, script)
|
||||
} else if !columnsEqual(modelCol, currentCol) {
|
||||
// Column exists but properties changed
|
||||
if modelCol.Type != currentCol.Type {
|
||||
if !columnTypesEqual(modelCol, currentCol) {
|
||||
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: modelTable.Name,
|
||||
ColumnName: modelCol.Name,
|
||||
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
||||
NewType: effectiveAlterColumnSQLType(modelCol),
|
||||
UsingExpr: buildAlterColumnUsingExpression(modelCol.Name, effectiveAlterColumnSQLType(modelCol)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -606,12 +618,11 @@ func buildIndexColumnExpressions(table *models.Table, index *models.Index, index
|
||||
if table != nil {
|
||||
if col, ok := resolveIndexColumn(table, colName); ok && col != nil {
|
||||
colExpr = col.SQLName()
|
||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
||||
opClass := extractOperatorClass(index.Comment)
|
||||
if opClass == "" {
|
||||
opClass = "gin_trgm_ops"
|
||||
if strings.EqualFold(indexType, "gin") {
|
||||
opClass := ginOperatorClassForColumn(col, index.Comment)
|
||||
if opClass != "" {
|
||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||
}
|
||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -875,11 +886,21 @@ func columnsEqual(col1, col2 *models.Column) bool {
|
||||
if col1 == nil || col2 == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(col1.Type, col2.Type) &&
|
||||
return columnTypesEqual(col1, col2) &&
|
||||
col1.NotNull == col2.NotNull &&
|
||||
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
|
||||
}
|
||||
|
||||
func columnTypesEqual(col1, col2 *models.Column) bool {
|
||||
if col1 == nil || col2 == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(
|
||||
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col1)),
|
||||
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col2)),
|
||||
)
|
||||
}
|
||||
|
||||
// constraintsEqual checks if two constraints are equal
|
||||
func constraintsEqual(c1, c2 *models.Constraint) bool {
|
||||
if c1 == nil || c2 == nil {
|
||||
|
||||
@@ -97,6 +97,160 @@ func TestWriteMigration_ArrayDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_AltersColumnTypeWhenActualTypeDiffers(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
currentTable := models.InitTable("learnings", "public")
|
||||
currentDetails := models.InitColumn("details", "learnings", "public")
|
||||
currentDetails.Type = "jsonb"
|
||||
currentTable.Columns["details"] = currentDetails
|
||||
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
modelTable := models.InitTable("learnings", "public")
|
||||
modelDetails := models.InitColumn("details", "learnings", "public")
|
||||
modelDetails.Type = "text"
|
||||
modelTable.Columns["details"] = modelDetails
|
||||
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "ALTER TABLE public.learnings") || !strings.Contains(output, "ALTER COLUMN details TYPE text") {
|
||||
t.Fatalf("expected migration to alter mismatched column type, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, `ALTER COLUMN details TYPE text USING details::text;`) {
|
||||
t.Fatalf("expected migration type alter to include USING cast, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
currentTable := models.InitTable("learnings", "public")
|
||||
currentID := models.InitColumn("id", "learnings", "public")
|
||||
currentID.Type = "uuid"
|
||||
currentTable.Columns["id"] = currentID
|
||||
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
modelTable := models.InitTable("learnings", "public")
|
||||
modelID := models.InitColumn("id", "learnings", "public")
|
||||
modelID.Type = "bigserial"
|
||||
modelTable.Columns["id"] = modelID
|
||||
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
|
||||
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
|
||||
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
|
||||
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_ArrayAlterIncludesUsingCast(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
currentTable := models.InitTable("learnings", "public")
|
||||
currentTags := models.InitColumn("tags", "learnings", "public")
|
||||
currentTags.Type = "text"
|
||||
currentTable.Columns["tags"] = currentTags
|
||||
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
modelTable := models.InitTable("learnings", "public")
|
||||
modelTags := models.InitColumn("tags", "learnings", "public")
|
||||
modelTags.Type = "text[]"
|
||||
modelTable.Columns["tags"] = modelTags
|
||||
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
|
||||
t.Fatalf("expected array alter to include USING cast, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_DoesNotAlterEquivalentNormalizedColumnType(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
currentTable := models.InitTable("users", "public")
|
||||
currentEmail := models.InitColumn("email", "users", "public")
|
||||
currentEmail.Type = "character varying"
|
||||
currentEmail.Length = 255
|
||||
currentTable.Columns["email"] = currentEmail
|
||||
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
modelTable := models.InitTable("users", "public")
|
||||
modelEmail := models.InitColumn("email", "users", "public")
|
||||
modelEmail.Type = "varchar(255)"
|
||||
modelTable.Columns["email"] = modelEmail
|
||||
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if strings.Contains(output, "ALTER COLUMN email TYPE") {
|
||||
t.Fatalf("did not expect alter type for equivalent normalized types, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
@@ -132,6 +286,9 @@ func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") {
|
||||
t.Fatalf("expected trigram extension for text GIN migration index, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
|
||||
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
|
||||
}
|
||||
@@ -212,14 +369,98 @@ func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *tes
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "USING gin (tags)") {
|
||||
t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output)
|
||||
if !strings.Contains(output, "USING gin (tags array_ops)") {
|
||||
t.Fatalf("expected GIN array index with array_ops, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "gin_trgm_ops") {
|
||||
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("learnings", "public")
|
||||
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||
detailsCol.Type = "jsonb"
|
||||
table.Columns["details"] = detailsCol
|
||||
|
||||
index := &models.Index{
|
||||
Name: "idx_learnings_details",
|
||||
Type: "gin",
|
||||
Columns: []string{"details"},
|
||||
}
|
||||
table.Indexes[index.Name] = index
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
|
||||
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "gin_trgm_ops") {
|
||||
t.Fatalf("did not expect gin_trgm_ops for jsonb migration index, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_GinIndexOnJSONBIgnoresIncompatibleTrigramOperatorClass(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("learnings", "public")
|
||||
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||
detailsCol.Type = "jsonb"
|
||||
table.Columns["details"] = detailsCol
|
||||
|
||||
index := &models.Index{
|
||||
Name: "idx_learnings_details",
|
||||
Type: "gin",
|
||||
Columns: []string{"details"},
|
||||
Comment: "gin_trgm_ops",
|
||||
}
|
||||
table.Indexes[index.Name] = index
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteMigration(model, current); err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
|
||||
t.Fatalf("expected incompatible trigram hint on jsonb to fall back to jsonb_ops, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_WithAudit(t *testing.T) {
|
||||
// Current database (empty)
|
||||
current := models.InitDatabase("testdb")
|
||||
|
||||
@@ -95,6 +95,16 @@ type AlterColumnTypeData struct {
|
||||
TableName string
|
||||
ColumnName string
|
||||
NewType string
|
||||
UsingExpr string
|
||||
}
|
||||
|
||||
type AlterColumnTypeWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
NewType string
|
||||
EquivalentTypes string
|
||||
UsingExpr string
|
||||
}
|
||||
|
||||
// AlterColumnDefaultData contains data for alter column default template
|
||||
@@ -302,6 +312,15 @@ func (te *TemplateExecutor) ExecuteAlterColumnType(data AlterColumnTypeData) (st
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (te *TemplateExecutor) ExecuteAlterColumnTypeWithCheck(data AlterColumnTypeWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "alter_column_type_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute alter_column_type_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteAlterColumnDefault executes the alter column default template
|
||||
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
current_type text;
|
||||
BEGIN
|
||||
SELECT pg_catalog.format_type(a.atttypid, a.atttypmod)
|
||||
INTO current_type
|
||||
FROM pg_attribute a
|
||||
JOIN pg_class t ON t.oid = a.attrelid
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
WHERE n.nspname = '{{.SchemaName}}'
|
||||
AND t.relname = '{{.TableName}}'
|
||||
AND a.attname = '{{.ColumnName}}'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped;
|
||||
|
||||
IF current_type IS NOT NULL
|
||||
AND current_type <> ALL(ARRAY[{{.EquivalentTypes}}]) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -143,6 +143,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
||||
}
|
||||
|
||||
if schemaRequiresPGTrgm(schema) {
|
||||
statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`)
|
||||
}
|
||||
|
||||
// Phase 2: Create sequences
|
||||
for _, table := range schema.Tables {
|
||||
pk := table.GetPrimaryKey()
|
||||
@@ -181,6 +185,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
}
|
||||
statements = append(statements, addColStmts...)
|
||||
|
||||
alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate alter column type statements: %w", err)
|
||||
}
|
||||
statements = append(statements, alterTypeStmts...)
|
||||
|
||||
// Phase 4: Primary keys
|
||||
for _, table := range schema.Tables {
|
||||
// First check for explicit PrimaryKeyConstraint
|
||||
@@ -262,13 +272,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
for _, colName := range index.Columns {
|
||||
colExpr := colName
|
||||
if col, ok := resolveIndexColumn(table, colName); ok {
|
||||
// For GIN indexes on text columns, add operator class
|
||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
||||
opClass := extractOperatorClass(index.Comment)
|
||||
if opClass == "" {
|
||||
opClass = "gin_trgm_ops"
|
||||
if strings.EqualFold(indexType, "gin") {
|
||||
if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" {
|
||||
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
||||
}
|
||||
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
||||
}
|
||||
}
|
||||
columnExprs = append(columnExprs, colExpr)
|
||||
@@ -437,6 +444,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name))
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
columns := getSortedColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
targetType := effectiveAlterColumnSQLType(col)
|
||||
stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ColumnName: col.Name,
|
||||
NewType: targetType,
|
||||
EquivalentTypes: equivalentTypeListSQL(targetType),
|
||||
UsingExpr: buildAlterColumnUsingExpression(col.Name, targetType),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
||||
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
||||
statements := []string{}
|
||||
@@ -489,31 +523,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
parts := []string{col.SQLName()}
|
||||
|
||||
// Type with length/precision - convert to valid PostgreSQL type
|
||||
baseType := pgsql.ConvertSQLType(col.Type)
|
||||
typeStr := baseType
|
||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
||||
|
||||
// Only add size specifiers for types that support them
|
||||
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
||||
if pgsql.SupportsLength(baseType) {
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||
} else if isTextTypeWithoutLength(baseType) {
|
||||
// Convert text with length to varchar
|
||||
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||
}
|
||||
// For types that don't support length (integer, bigint, etc.), ignore the length
|
||||
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
||||
if pgsql.SupportsPrecision(baseType) {
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||
}
|
||||
}
|
||||
// For types that don't support precision, ignore it
|
||||
}
|
||||
parts = append(parts, typeStr)
|
||||
parts = append(parts, effectiveColumnSQLType(col))
|
||||
|
||||
// NOT NULL
|
||||
if col.NotNull {
|
||||
@@ -535,6 +545,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func effectiveColumnSQLType(col *models.Column) string {
|
||||
if col == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseType := pgsql.ConvertSQLType(col.Type)
|
||||
typeStr := baseType
|
||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
||||
|
||||
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
||||
if pgsql.SupportsLength(baseType) {
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||
} else if isTextTypeWithoutLength(baseType) {
|
||||
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||
}
|
||||
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
||||
if pgsql.SupportsPrecision(baseType) {
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typeStr
|
||||
}
|
||||
|
||||
func effectiveAlterColumnSQLType(col *models.Column) string {
|
||||
typeStr := effectiveColumnSQLType(col)
|
||||
switch strings.ToLower(strings.TrimSpace(typeStr)) {
|
||||
case "smallserial":
|
||||
return "smallint"
|
||||
case "serial":
|
||||
return "integer"
|
||||
case "bigserial":
|
||||
return "bigint"
|
||||
default:
|
||||
return typeStr
|
||||
}
|
||||
}
|
||||
|
||||
func buildAlterColumnUsingExpression(columnName, targetType string) string {
|
||||
if strings.TrimSpace(columnName) == "" || strings.TrimSpace(targetType) == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType)
|
||||
}
|
||||
|
||||
func equivalentTypeListSQL(sqlType string) string {
|
||||
variants := pgsql.EquivalentSQLTypeVariants(sqlType)
|
||||
quoted := make([]string, 0, len(variants))
|
||||
for _, variant := range variants {
|
||||
quoted = append(quoted, fmt.Sprintf("'%s'", escapeQuote(variant)))
|
||||
}
|
||||
return strings.Join(quoted, ", ")
|
||||
}
|
||||
|
||||
// WriteSchema writes a single schema and all its tables
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
if w.writer == nil {
|
||||
@@ -546,6 +614,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.writeRequiredExtensions(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 2: Create sequences (priority 80)
|
||||
if err := w.writeSequences(schema); err != nil {
|
||||
return err
|
||||
@@ -561,6 +633,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.writeAlterColumnTypes(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 4: Create primary keys (priority 160)
|
||||
if err := w.writePrimaryKeys(schema); err != nil {
|
||||
return err
|
||||
@@ -661,6 +737,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Writer) writeRequiredExtensions(schema *models.Schema) error {
|
||||
if !schemaRequiresPGTrgm(schema) {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;")
|
||||
fmt.Fprintln(w.writer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
||||
func (w *Writer) writeSequences(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
||||
@@ -754,6 +840,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name)
|
||||
|
||||
statements, err := w.GenerateAlterColumnTypeStatements(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range statements[1:] {
|
||||
fmt.Fprint(w.writer, stmt)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
||||
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
||||
@@ -857,13 +958,11 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
||||
for _, colName := range index.Columns {
|
||||
if col, ok := resolveIndexColumn(table, colName); ok {
|
||||
colExpr := col.SQLName()
|
||||
// For GIN indexes on text columns, add operator class
|
||||
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
|
||||
opClass := extractOperatorClass(index.Comment)
|
||||
if opClass == "" {
|
||||
opClass = "gin_trgm_ops"
|
||||
if strings.EqualFold(index.Type, "gin") {
|
||||
opClass := ginOperatorClassForColumn(col, index.Comment)
|
||||
if opClass != "" {
|
||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||
}
|
||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||
}
|
||||
columnExprs = append(columnExprs, colExpr)
|
||||
}
|
||||
@@ -1250,25 +1349,101 @@ func isIntegerType(colType string) bool {
|
||||
}
|
||||
|
||||
// isTextType checks if a column type is a text type (for GIN index operator class)
|
||||
func isTextType(colType string) bool {
|
||||
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
||||
lowerType := strings.ToLower(colType)
|
||||
if strings.HasSuffix(lowerType, "[]") {
|
||||
return false
|
||||
}
|
||||
for _, t := range textTypes {
|
||||
if strings.HasPrefix(lowerType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
// func isTextType(colType string) bool {
|
||||
// textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
||||
// lowerType := strings.ToLower(colType)
|
||||
// if strings.HasSuffix(lowerType, "[]") {
|
||||
// return false
|
||||
// }
|
||||
// for _, t := range textTypes {
|
||||
// if strings.HasPrefix(lowerType, t) {
|
||||
// return true
|
||||
// }
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
|
||||
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
||||
func isTextTypeWithoutLength(colType string) bool {
|
||||
return strings.EqualFold(colType, "text")
|
||||
}
|
||||
|
||||
func ginOperatorClassForColumn(col *models.Column, comment string) string {
|
||||
if col == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sqlType := effectiveColumnSQLType(col)
|
||||
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
isArray := pgsql.IsArrayType(sqlType)
|
||||
requested := extractOperatorClass(comment)
|
||||
|
||||
if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) {
|
||||
return requested
|
||||
}
|
||||
|
||||
if isArray {
|
||||
return "array_ops"
|
||||
}
|
||||
|
||||
switch {
|
||||
case isTextGinBaseType(baseType):
|
||||
return "gin_trgm_ops"
|
||||
case baseType == "jsonb":
|
||||
return "jsonb_ops"
|
||||
default:
|
||||
return requested
|
||||
}
|
||||
}
|
||||
|
||||
func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool {
|
||||
switch opClass {
|
||||
case "gin_trgm_ops", "gin_bigm_ops":
|
||||
return !isArray && isTextGinBaseType(baseType)
|
||||
case "jsonb_ops", "jsonb_path_ops":
|
||||
return !isArray && baseType == "jsonb"
|
||||
case "array_ops":
|
||||
return isArray
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func isTextGinBaseType(baseType string) bool {
|
||||
switch baseType {
|
||||
case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func schemaRequiresPGTrgm(schema *models.Schema) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
for _, table := range schema.Tables {
|
||||
if table == nil {
|
||||
continue
|
||||
}
|
||||
for _, index := range table.Indexes {
|
||||
if index == nil || !strings.EqualFold(index.Type, "gin") {
|
||||
continue
|
||||
}
|
||||
for _, colName := range index.Columns {
|
||||
col, ok := resolveIndexColumn(table, colName)
|
||||
if !ok || col == nil {
|
||||
continue
|
||||
}
|
||||
if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
|
||||
if table == nil {
|
||||
return nil, false
|
||||
|
||||
@@ -116,8 +116,8 @@ func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *test
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, `USING gin (tags)`) {
|
||||
t.Fatalf("expected GIN index on array column without explicit trigram opclass, got:\n%s", output)
|
||||
if !strings.Contains(output, `USING gin (tags array_ops)`) {
|
||||
t.Fatalf("expected GIN index on array column with array_ops, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "gin_trgm_ops") {
|
||||
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
|
||||
@@ -153,11 +153,51 @@ func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *tes
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) {
|
||||
t.Fatalf("expected trigram extension for text GIN index, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) {
|
||||
t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("learnings", "public")
|
||||
|
||||
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||
detailsCol.Type = "jsonb"
|
||||
table.Columns["details"] = detailsCol
|
||||
|
||||
index := &models.Index{
|
||||
Name: "idx_learnings_details",
|
||||
Type: "gin",
|
||||
Columns: []string{"details"},
|
||||
}
|
||||
table.Indexes[index.Name] = index
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteDatabase(db); err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, `USING gin (details jsonb_ops)`) {
|
||||
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "gin_trgm_ops") {
|
||||
t.Fatalf("did not expect gin_trgm_ops for jsonb column, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteForeignKeys(t *testing.T) {
|
||||
// Create a test database with two related tables
|
||||
db := models.InitDatabase("testdb")
|
||||
@@ -1018,3 +1058,82 @@ func TestWriteAddColumnStatements(t *testing.T) {
|
||||
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSchema_EmitsGuardedAlterColumnTypeStatements(t *testing.T) {
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("agent_skills", "public")
|
||||
|
||||
nameCol := models.InitColumn("name", "agent_skills", "public")
|
||||
nameCol.Type = "character varying"
|
||||
nameCol.Length = 255
|
||||
table.Columns["name"] = nameCol
|
||||
|
||||
tagsCol := models.InitColumn("tags", "agent_skills", "public")
|
||||
tagsCol.Type = "text[]"
|
||||
table.Columns["tags"] = tagsCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteDatabase(db); err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "-- Alter column types for schema: public") {
|
||||
t.Fatalf("expected alter column type section, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "pg_catalog.format_type") {
|
||||
t.Fatalf("expected guarded live-type check, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ALTER COLUMN name TYPE character varying(255)") {
|
||||
t.Fatalf("expected guarded alter for character varying(255), got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ARRAY['varchar(255)', 'character varying(255)']") {
|
||||
t.Fatalf("expected equivalent type spellings for varchar guard, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ALTER COLUMN tags TYPE text[]") {
|
||||
t.Fatalf("expected guarded alter for array type, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
|
||||
t.Fatalf("expected guarded alter for array type to include USING cast, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSchema_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("learnings", "public")
|
||||
idCol := models.InitColumn("id", "learnings", "public")
|
||||
idCol.Type = "bigserial"
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
writer.writer = &buf
|
||||
|
||||
if err := writer.WriteDatabase(db); err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
|
||||
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
|
||||
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
|
||||
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user