feat(writer): 🎉 Enhance PostgreSQL writer, fixed bugs found using origin
Some checks failed
CI / Test (1.24) (push) Failing after -24m5s
CI / Test (1.25) (push) Successful in -23m53s
CI / Build (push) Failing after -26m29s
CI / Lint (push) Successful in -26m12s
Integration Tests / Integration Tests (push) Successful in -26m20s
Release / Build and Release (push) Successful in -25m7s

This commit is contained in:
2026-01-28 21:59:25 +02:00
parent 6f55505444
commit 91b6046b9b
7 changed files with 1468 additions and 25 deletions

View File

@@ -155,13 +155,30 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
indexType = "btree"
}
// Build column expressions with operator class support for GIN indexes
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
if col, ok := table.Columns[colName]; ok {
// For GIN indexes on text columns, add operator class
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
}
colExpr = fmt.Sprintf("%s %s", colName, opClass)
}
}
columnExprs = append(columnExprs, colExpr)
}
whereClause := ""
if index.Where != "" {
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
}
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause)
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
statements = append(statements, stmt)
}
}
@@ -273,12 +290,14 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
if col.Default != nil {
switch v := col.Default.(type) {
case string:
if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") {
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
} else if v == "true" || v == "false" {
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
cleanDefault := stripBackticks(v)
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
} else if cleanDefault == "true" || cleanDefault == "false" {
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
} else {
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v)))
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
}
case bool:
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
@@ -408,8 +427,10 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
// Add default value if present
if col.Default != "" {
colDef += fmt.Sprintf(" DEFAULT %s", col.Default)
if col.Default != nil && col.Default != "" {
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
defaultVal := fmt.Sprintf("%v", col.Default)
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
}
columnDefs = append(columnDefs, colDef)
@@ -503,15 +524,24 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName())
}
// Build column list
columnNames := make([]string, 0, len(index.Columns))
// Build column list with operator class support for GIN indexes
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
if col, ok := table.Columns[colName]; ok {
columnNames = append(columnNames, col.SQLName())
colExpr := col.SQLName()
// For GIN indexes on text columns, add operator class
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
}
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
columnExprs = append(columnExprs, colExpr)
}
}
if len(columnNames) == 0 {
if len(columnExprs) == 0 {
continue
}
@@ -520,10 +550,20 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
unique = "UNIQUE "
}
indexType := index.Type
if indexType == "" {
indexType = "btree"
}
whereClause := ""
if index.Where != "" {
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
}
fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n",
unique, indexName)
fmt.Fprintf(w.writer, " ON %s.%s USING btree (%s);\n\n",
schema.SQLName(), table.SQLName(), strings.Join(columnNames, ", "))
fmt.Fprintf(w.writer, " ON %s.%s USING %s (%s)%s;\n\n",
schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
}
}
@@ -718,11 +758,46 @@ func isIntegerType(colType string) bool {
return false
}
// isTextType checks if a column type is a text type (for GIN index operator class)
func isTextType(colType string) bool {
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
lowerType := strings.ToLower(colType)
for _, t := range textTypes {
if strings.HasPrefix(lowerType, t) {
return true
}
}
return false
}
// extractOperatorClass extracts operator class from index comment/note
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
func extractOperatorClass(comment string) string {
if comment == "" {
return ""
}
lowerComment := strings.ToLower(comment)
// Common GIN/GiST operator classes
opClasses := []string{"gin_trgm_ops", "gist_trgm_ops", "gin_bigm_ops", "jsonb_ops", "jsonb_path_ops", "array_ops"}
for _, op := range opClasses {
if strings.Contains(lowerComment, op) {
return op
}
}
return ""
}
// escapeQuote escapes single quotes in strings for SQL
func escapeQuote(s string) string {
return strings.ReplaceAll(s, "'", "''")
}
// stripBackticks removes backticks from SQL expressions
// DBML uses backticks for SQL expressions like `now()`, but PostgreSQL doesn't use backticks
func stripBackticks(s string) string {
return strings.ReplaceAll(s, "`", "")
}
// extractSequenceName extracts sequence name from nextval() expression
// Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq"
func extractSequenceName(defaultExpr string) string {