feat: Enhance PostgreSQL type handling and migration scripts

- Introduced equivalent base types and variants for PostgreSQL types to normalize type comparisons.
- Added functions for normalizing SQL types and retrieving equivalent type variants.
- Updated migration writer to handle type alterations with checks for existing types.
- Implemented logic to create necessary extensions (e.g., pg_trgm) based on schema requirements.
- Enhanced tests to cover new functionality for type normalization and migration handling.
- Improved handling of GIN indexes to use appropriate operator classes based on column types.
This commit is contained in:
2026-05-05 14:50:34 +02:00
parent 72200ea72e
commit 2d97a47ee1
14 changed files with 1042 additions and 65 deletions

View File

@@ -143,6 +143,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
}
if schemaRequiresPGTrgm(schema) {
statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`)
}
// Phase 2: Create sequences
for _, table := range schema.Tables {
pk := table.GetPrimaryKey()
@@ -181,6 +185,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
}
statements = append(statements, addColStmts...)
alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate alter column type statements: %w", err)
}
statements = append(statements, alterTypeStmts...)
// Phase 4: Primary keys
for _, table := range schema.Tables {
// First check for explicit PrimaryKeyConstraint
@@ -262,13 +272,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
for _, colName := range index.Columns {
colExpr := colName
if col, ok := resolveIndexColumn(table, colName); ok {
// For GIN indexes on text columns, add operator class
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
if strings.EqualFold(indexType, "gin") {
if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" {
colExpr = fmt.Sprintf("%s %s", colName, opClass)
}
colExpr = fmt.Sprintf("%s %s", colName, opClass)
}
}
columnExprs = append(columnExprs, colExpr)
@@ -437,6 +444,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
return statements, nil
}
func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) {
statements := []string{}
statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name))
for _, table := range schema.Tables {
columns := getSortedColumns(table.Columns)
for _, col := range columns {
targetType := effectiveAlterColumnSQLType(col)
stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ColumnName: col.Name,
NewType: targetType,
EquivalentTypes: equivalentTypeListSQL(targetType),
UsingExpr: buildAlterColumnUsingExpression(col.Name, targetType),
})
if err != nil {
return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
}
statements = append(statements, stmt)
}
}
return statements, nil
}
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
statements := []string{}
@@ -489,31 +523,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
func (w *Writer) generateColumnDefinition(col *models.Column) string {
parts := []string{col.SQLName()}
// Type with length/precision - convert to valid PostgreSQL type
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
// Only add size specifiers for types that support them
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
if pgsql.SupportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
// Convert text with length to varchar
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
// For types that don't support length (integer, bigint, etc.), ignore the length
} else if !hasExplicitTypeModifier && col.Precision > 0 {
if pgsql.SupportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
// For types that don't support precision, ignore it
}
parts = append(parts, typeStr)
parts = append(parts, effectiveColumnSQLType(col))
// NOT NULL
if col.NotNull {
@@ -535,6 +545,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
return strings.Join(parts, " ")
}
func effectiveColumnSQLType(col *models.Column) string {
if col == nil {
return ""
}
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
if pgsql.SupportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
} else if !hasExplicitTypeModifier && col.Precision > 0 {
if pgsql.SupportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
}
return typeStr
}
func effectiveAlterColumnSQLType(col *models.Column) string {
typeStr := effectiveColumnSQLType(col)
switch strings.ToLower(strings.TrimSpace(typeStr)) {
case "smallserial":
return "smallint"
case "serial":
return "integer"
case "bigserial":
return "bigint"
default:
return typeStr
}
}
func buildAlterColumnUsingExpression(columnName, targetType string) string {
if strings.TrimSpace(columnName) == "" || strings.TrimSpace(targetType) == "" {
return ""
}
return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType)
}
func equivalentTypeListSQL(sqlType string) string {
variants := pgsql.EquivalentSQLTypeVariants(sqlType)
quoted := make([]string, 0, len(variants))
for _, variant := range variants {
quoted = append(quoted, fmt.Sprintf("'%s'", escapeQuote(variant)))
}
return strings.Join(quoted, ", ")
}
// WriteSchema writes a single schema and all its tables
func (w *Writer) WriteSchema(schema *models.Schema) error {
if w.writer == nil {
@@ -546,6 +614,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
if err := w.writeRequiredExtensions(schema); err != nil {
return err
}
// Phase 2: Create sequences (priority 80)
if err := w.writeSequences(schema); err != nil {
return err
@@ -561,6 +633,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
if err := w.writeAlterColumnTypes(schema); err != nil {
return err
}
// Phase 4: Create primary keys (priority 160)
if err := w.writePrimaryKeys(schema); err != nil {
return err
@@ -661,6 +737,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error {
return nil
}
func (w *Writer) writeRequiredExtensions(schema *models.Schema) error {
if !schemaRequiresPGTrgm(schema) {
return nil
}
fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;")
fmt.Fprintln(w.writer)
return nil
}
// writeSequences generates CREATE SEQUENCE statements for identity columns
func (w *Writer) writeSequences(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
@@ -754,6 +840,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
return nil
}
func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name)
statements, err := w.GenerateAlterColumnTypeStatements(schema)
if err != nil {
return err
}
for _, stmt := range statements[1:] {
fmt.Fprint(w.writer, stmt)
fmt.Fprint(w.writer, "\n")
}
return nil
}
// writePrimaryKeys generates ALTER TABLE statements for primary keys
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
@@ -857,13 +958,11 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
for _, colName := range index.Columns {
if col, ok := resolveIndexColumn(table, colName); ok {
colExpr := col.SQLName()
// For GIN indexes on text columns, add operator class
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
if strings.EqualFold(index.Type, "gin") {
opClass := ginOperatorClassForColumn(col, index.Comment)
if opClass != "" {
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
columnExprs = append(columnExprs, colExpr)
}
@@ -1250,25 +1349,101 @@ func isIntegerType(colType string) bool {
}
// isTextType checks if a column type is a text type (for GIN index operator class)
func isTextType(colType string) bool {
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
lowerType := strings.ToLower(colType)
if strings.HasSuffix(lowerType, "[]") {
return false
}
for _, t := range textTypes {
if strings.HasPrefix(lowerType, t) {
return true
}
}
return false
}
// func isTextType(colType string) bool {
// textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
// lowerType := strings.ToLower(colType)
// if strings.HasSuffix(lowerType, "[]") {
// return false
// }
// for _, t := range textTypes {
// if strings.HasPrefix(lowerType, t) {
// return true
// }
// }
// return false
// }
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
func isTextTypeWithoutLength(colType string) bool {
return strings.EqualFold(colType, "text")
}
func ginOperatorClassForColumn(col *models.Column, comment string) string {
if col == nil {
return ""
}
sqlType := effectiveColumnSQLType(col)
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
isArray := pgsql.IsArrayType(sqlType)
requested := extractOperatorClass(comment)
if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) {
return requested
}
if isArray {
return "array_ops"
}
switch {
case isTextGinBaseType(baseType):
return "gin_trgm_ops"
case baseType == "jsonb":
return "jsonb_ops"
default:
return requested
}
}
func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool {
switch opClass {
case "gin_trgm_ops", "gin_bigm_ops":
return !isArray && isTextGinBaseType(baseType)
case "jsonb_ops", "jsonb_path_ops":
return !isArray && baseType == "jsonb"
case "array_ops":
return isArray
default:
return true
}
}
func isTextGinBaseType(baseType string) bool {
switch baseType {
case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar":
return true
default:
return false
}
}
func schemaRequiresPGTrgm(schema *models.Schema) bool {
if schema == nil {
return false
}
for _, table := range schema.Tables {
if table == nil {
continue
}
for _, index := range table.Indexes {
if index == nil || !strings.EqualFold(index.Type, "gin") {
continue
}
for _, colName := range index.Columns {
col, ok := resolveIndexColumn(table, colName)
if !ok || col == nil {
continue
}
if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" {
return true
}
}
}
}
return false
}
func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
if table == nil {
return nil, false