feat(writer): 🎉 Enhance PostgreSQL writer, fixed bugs found using origin
Some checks failed
CI / Test (1.24) (push) Failing after -24m5s
CI / Test (1.25) (push) Successful in -23m53s
CI / Build (push) Failing after -26m29s
CI / Lint (push) Successful in -26m12s
Integration Tests / Integration Tests (push) Successful in -26m20s
Release / Build and Release (push) Successful in -25m7s
Some checks failed
CI / Test (1.24) (push) Failing after -24m5s
CI / Test (1.25) (push) Successful in -23m53s
CI / Build (push) Failing after -26m29s
CI / Lint (push) Successful in -26m12s
Integration Tests / Integration Tests (push) Successful in -26m20s
Release / Build and Release (push) Successful in -25m7s
This commit is contained in:
@@ -587,10 +587,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
refOp := strings.TrimSpace(refStr)
|
||||
var isReverse bool
|
||||
if strings.HasPrefix(refOp, "<") {
|
||||
isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse)
|
||||
} else if strings.HasPrefix(refOp, ">") {
|
||||
isReverse = !column.IsPrimaryKey // > on FK means reverse
|
||||
// < means "is referenced by" - only makes sense on PK columns
|
||||
isReverse = column.IsPrimaryKey
|
||||
}
|
||||
// > means "references" - always a forward FK, never reverse
|
||||
|
||||
constraint = r.parseRef(refStr)
|
||||
if constraint != nil {
|
||||
|
||||
@@ -329,10 +329,10 @@ func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint)
|
||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||
|
||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||
relationship.FromTable = fk.ReferencedTable
|
||||
relationship.FromSchema = fk.ReferencedSchema
|
||||
relationship.ToTable = table.Name
|
||||
relationship.ToSchema = table.Schema
|
||||
relationship.FromTable = table.Name
|
||||
relationship.FromSchema = table.Schema
|
||||
relationship.ToTable = fk.ReferencedTable
|
||||
relationship.ToSchema = fk.ReferencedSchema
|
||||
relationship.ForeignKey = fk.Name
|
||||
|
||||
// Store constraint actions in properties
|
||||
|
||||
@@ -328,12 +328,12 @@ func TestDeriveRelationship(t *testing.T) {
|
||||
t.Errorf("Expected relationship type %s, got %s", models.OneToMany, rel.Type)
|
||||
}
|
||||
|
||||
if rel.FromTable != "users" {
|
||||
t.Errorf("Expected FromTable 'users', got '%s'", rel.FromTable)
|
||||
if rel.FromTable != "orders" {
|
||||
t.Errorf("Expected FromTable 'orders', got '%s'", rel.FromTable)
|
||||
}
|
||||
|
||||
if rel.ToTable != "orders" {
|
||||
t.Errorf("Expected ToTable 'orders', got '%s'", rel.ToTable)
|
||||
if rel.ToTable != "users" {
|
||||
t.Errorf("Expected ToTable 'users', got '%s'", rel.ToTable)
|
||||
}
|
||||
|
||||
if rel.ForeignKey != "fk_orders_user_id" {
|
||||
|
||||
@@ -155,13 +155,30 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
indexType = "btree"
|
||||
}
|
||||
|
||||
// Build column expressions with operator class support for GIN indexes
|
||||
columnExprs := make([]string, 0, len(index.Columns))
|
||||
for _, colName := range index.Columns {
|
||||
colExpr := colName
|
||||
if col, ok := table.Columns[colName]; ok {
|
||||
// For GIN indexes on text columns, add operator class
|
||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
||||
opClass := extractOperatorClass(index.Comment)
|
||||
if opClass == "" {
|
||||
opClass = "gin_trgm_ops"
|
||||
}
|
||||
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
||||
}
|
||||
}
|
||||
columnExprs = append(columnExprs, colExpr)
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if index.Where != "" {
|
||||
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause)
|
||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -273,12 +290,14 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
if col.Default != nil {
|
||||
switch v := col.Default.(type) {
|
||||
case string:
|
||||
if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
||||
} else if v == "true" || v == "false" {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
||||
cleanDefault := stripBackticks(v)
|
||||
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
||||
} else if cleanDefault == "true" || cleanDefault == "false" {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v)))
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
|
||||
}
|
||||
case bool:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
@@ -408,8 +427,10 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
||||
colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
|
||||
|
||||
// Add default value if present
|
||||
if col.Default != "" {
|
||||
colDef += fmt.Sprintf(" DEFAULT %s", col.Default)
|
||||
if col.Default != nil && col.Default != "" {
|
||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
||||
defaultVal := fmt.Sprintf("%v", col.Default)
|
||||
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
|
||||
}
|
||||
|
||||
columnDefs = append(columnDefs, colDef)
|
||||
@@ -503,15 +524,24 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
||||
indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName())
|
||||
}
|
||||
|
||||
// Build column list
|
||||
columnNames := make([]string, 0, len(index.Columns))
|
||||
// Build column list with operator class support for GIN indexes
|
||||
columnExprs := make([]string, 0, len(index.Columns))
|
||||
for _, colName := range index.Columns {
|
||||
if col, ok := table.Columns[colName]; ok {
|
||||
columnNames = append(columnNames, col.SQLName())
|
||||
colExpr := col.SQLName()
|
||||
// For GIN indexes on text columns, add operator class
|
||||
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
|
||||
opClass := extractOperatorClass(index.Comment)
|
||||
if opClass == "" {
|
||||
opClass = "gin_trgm_ops"
|
||||
}
|
||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||
}
|
||||
columnExprs = append(columnExprs, colExpr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(columnNames) == 0 {
|
||||
if len(columnExprs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -520,10 +550,20 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
||||
unique = "UNIQUE "
|
||||
}
|
||||
|
||||
indexType := index.Type
|
||||
if indexType == "" {
|
||||
indexType = "btree"
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if index.Where != "" {
|
||||
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n",
|
||||
unique, indexName)
|
||||
fmt.Fprintf(w.writer, " ON %s.%s USING btree (%s);\n\n",
|
||||
schema.SQLName(), table.SQLName(), strings.Join(columnNames, ", "))
|
||||
fmt.Fprintf(w.writer, " ON %s.%s USING %s (%s)%s;\n\n",
|
||||
schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -718,11 +758,46 @@ func isIntegerType(colType string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isTextType checks if a column type is a text type (for GIN index operator class)
|
||||
func isTextType(colType string) bool {
|
||||
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
||||
lowerType := strings.ToLower(colType)
|
||||
for _, t := range textTypes {
|
||||
if strings.HasPrefix(lowerType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractOperatorClass extracts operator class from index comment/note
|
||||
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
|
||||
func extractOperatorClass(comment string) string {
|
||||
if comment == "" {
|
||||
return ""
|
||||
}
|
||||
lowerComment := strings.ToLower(comment)
|
||||
// Common GIN/GiST operator classes
|
||||
opClasses := []string{"gin_trgm_ops", "gist_trgm_ops", "gin_bigm_ops", "jsonb_ops", "jsonb_path_ops", "array_ops"}
|
||||
for _, op := range opClasses {
|
||||
if strings.Contains(lowerComment, op) {
|
||||
return op
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// escapeQuote escapes single quotes in strings for SQL
|
||||
func escapeQuote(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// stripBackticks removes backticks from SQL expressions
|
||||
// DBML uses backticks for SQL expressions like `now()`, but PostgreSQL doesn't use backticks
|
||||
func stripBackticks(s string) string {
|
||||
return strings.ReplaceAll(s, "`", "")
|
||||
}
|
||||
|
||||
// extractSequenceName extracts sequence name from nextval() expression
|
||||
// Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq"
|
||||
func extractSequenceName(defaultExpr string) string {
|
||||
|
||||
Reference in New Issue
Block a user