diff --git a/pkg/writers/pgsql/migration_writer.go b/pkg/writers/pgsql/migration_writer.go index fe1f23e..5272354 100644 --- a/pkg/writers/pgsql/migration_writer.go +++ b/pkg/writers/pgsql/migration_writer.go @@ -329,7 +329,11 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model // Column doesn't exist, add it defaultVal := "" if modelCol.Default != nil { - defaultVal = fmt.Sprintf("%v", modelCol.Default) + if value, ok := modelCol.Default.(string); ok { + defaultVal = writers.QuoteDefaultValue(value, modelCol.Type) + } else { + defaultVal = fmt.Sprintf("%v", modelCol.Default) + } } sql, err := w.executor.ExecuteAddColumn(AddColumnData{ @@ -382,7 +386,11 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model setDefault := modelCol.Default != nil defaultVal := "" if setDefault { - defaultVal = fmt.Sprintf("%v", modelCol.Default) + if value, ok := modelCol.Default.(string); ok { + defaultVal = writers.QuoteDefaultValue(value, modelCol.Type) + } else { + defaultVal = fmt.Sprintf("%v", modelCol.Default) + } } sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{ diff --git a/pkg/writers/pgsql/migration_writer_test.go b/pkg/writers/pgsql/migration_writer_test.go index da976b9..2d0f493 100644 --- a/pkg/writers/pgsql/migration_writer_test.go +++ b/pkg/writers/pgsql/migration_writer_test.go @@ -57,6 +57,46 @@ func TestWriteMigration_NewTable(t *testing.T) { } } +func TestWriteMigration_ArrayDefault(t *testing.T) { + current := models.InitDatabase("testdb") + currentSchema := models.InitSchema("public") + current.Schemas = append(current.Schemas, currentSchema) + + model := models.InitDatabase("testdb") + modelSchema := models.InitSchema("public") + + table := models.InitTable("plans", "public") + + tagsCol := models.InitColumn("tags", "plans", "public") + tagsCol.Type = "text[]" + tagsCol.NotNull = true + tagsCol.Default = "''{}''" + table.Columns["tags"] = tagsCol + + modelSchema.Tables = append(modelSchema.Tables, table) + model.Schemas = append(model.Schemas, modelSchema) + + var buf bytes.Buffer + writer, err := NewMigrationWriter(&writers.WriterOptions{}) + if err != nil { + t.Fatalf("Failed to create writer: %v", err) + } + writer.writer = &buf + + err = writer.WriteMigration(model, current) + if err != nil { + t.Fatalf("WriteMigration failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "tags text[] DEFAULT '{}' NOT NULL") { + t.Fatalf("expected normalized array default in migration, got:\n%s", output) + } + if strings.Contains(output, "'''{}'''") { + t.Fatalf("migration still contains triple-quoted array default:\n%s", output) + } +} + func TestWriteMigration_WithAudit(t *testing.T) { // Current database (empty) current := models.InitDatabase("testdb") diff --git a/pkg/writers/pgsql/templates.go b/pkg/writers/pgsql/templates.go index 38cf485..e910b82 100644 --- a/pkg/writers/pgsql/templates.go +++ b/pkg/writers/pgsql/templates.go @@ -8,6 +8,7 @@ import ( "text/template" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" ) //go:embed templates/*.tmpl @@ -495,7 +496,11 @@ func BuildCreateTableData(schemaName string, table *models.Table) CreateTableDat NotNull: col.NotNull, } if col.Default != nil { - colData.Default = fmt.Sprintf("%v", col.Default) + if value, ok := col.Default.(string); ok { + colData.Default = writers.QuoteDefaultValue(value, col.Type) + } else { + colData.Default = fmt.Sprintf("%v", col.Default) + } } columns = append(columns, colData) } diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 732273d..9a1b241 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -523,15 +523,7 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string { if col.Default != nil { switch v := col.Default.(type) { case string: - // Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't - cleanDefault := stripBackticks(v) - if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") { - parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault)) - } else if cleanDefault == "true" || cleanDefault == "false" { - parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault)) - } else { - parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault))) - } + parts = append(parts, fmt.Sprintf("DEFAULT %s", writers.QuoteDefaultValue(stripBackticks(v), col.Type))) case bool: parts = append(parts, fmt.Sprintf("DEFAULT %v", v)) default: diff --git a/pkg/writers/writer.go b/pkg/writers/writer.go index 86f92d9..0637ebf 100644 --- a/pkg/writers/writer.go +++ b/pkg/writers/writer.go @@ -110,8 +110,12 @@ func SanitizeFilename(name string) string { // Examples (bigint): "0" → "0" // Examples (timestamp): "now()" → "now()" (function call – never quoted) func QuoteDefaultValue(value, sqlType string) string { + value = strings.TrimSpace(value) + // Function calls are never quoted regardless of column type. - if strings.Contains(value, "(") || strings.Contains(value, ")") { + if strings.Contains(value, "(") || strings.Contains(value, ")") || + strings.Contains(value, "::") || + strings.HasPrefix(strings.ToUpper(value), "ARRAY[") { return value } @@ -121,6 +125,16 @@ func QuoteDefaultValue(value, sqlType string) string { baseType = baseType[:idx] } + if isArraySQLType(baseType) { + if arrayLiteral, ok := normalizeArrayDefaultLiteral(value); ok { + return quoteSQLLiteral(arrayLiteral) + } + } + + if isQuotedSQLLiteral(value) { + return value + } + // Types whose default values must NOT be quoted. unquotedTypes := map[string]bool{ // Integer types @@ -154,7 +168,32 @@ func QuoteDefaultValue(value, sqlType string) string { // Everything else (text, varchar, char, uuid, date, time, timestamp, json, …) // is treated as a quoted literal. - return "'" + value + "'" + return quoteSQLLiteral(value) +} + +func isArraySQLType(sqlType string) bool { + return strings.HasSuffix(sqlType, "[]") +} + +func normalizeArrayDefaultLiteral(value string) (string, bool) { + switch { + case strings.HasPrefix(value, "''{") && strings.HasSuffix(value, "}''"): + return value[2 : len(value)-2], true + case strings.HasPrefix(value, "'{") && strings.HasSuffix(value, "}'"): + return value[1 : len(value)-1], true + case strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}"): + return value, true + default: + return "", false + } +} + +func isQuotedSQLLiteral(value string) bool { + return len(value) >= 2 && strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") +} + +func quoteSQLLiteral(value string) string { + return "'" + strings.ReplaceAll(value, "'", "''") + "'" } // SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags. diff --git a/pkg/writers/writer_test.go b/pkg/writers/writer_test.go new file mode 100644 index 0000000..63a2830 --- /dev/null +++ b/pkg/writers/writer_test.go @@ -0,0 +1,54 @@ +package writers + +import "testing" + +func TestQuoteDefaultValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + sqlType string + want string + }{ + { + name: "text default is quoted", + value: "active", + sqlType: "text", + want: "'active'", + }, + { + name: "array default from bare literal is quoted once", + value: "{}", + sqlType: "text[]", + want: "'{}'", + }, + { + name: "array default from quoted literal is preserved", + value: "'{}'", + sqlType: "text[]", + want: "'{}'", + }, + { + name: "array default from double quoted literal is normalized", + value: "''{}''", + sqlType: "text[]", + want: "'{}'", + }, + { + name: "function default is left alone", + value: "now()", + sqlType: "timestamptz", + want: "now()", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := QuoteDefaultValue(tt.value, tt.sqlType) + if got != tt.want { + t.Fatalf("QuoteDefaultValue(%q, %q) = %q, want %q", tt.value, tt.sqlType, got, tt.want) + } + }) + } +}