fix(pgsql): handle default values for array types in migrations
* update default value quoting logic for PostgreSQL * add tests for array default value handling
This commit is contained in:
@@ -329,7 +329,11 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
// Column doesn't exist, add it
|
||||
defaultVal := ""
|
||||
if modelCol.Default != nil {
|
||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||
if value, ok := modelCol.Default.(string); ok {
|
||||
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
|
||||
} else {
|
||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||
}
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteAddColumn(AddColumnData{
|
||||
@@ -382,7 +386,11 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
setDefault := modelCol.Default != nil
|
||||
defaultVal := ""
|
||||
if setDefault {
|
||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||
if value, ok := modelCol.Default.(string); ok {
|
||||
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
|
||||
} else {
|
||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||
}
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
|
||||
|
||||
@@ -57,6 +57,46 @@ func TestWriteMigration_NewTable(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_ArrayDefault(t *testing.T) {
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("public")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("plans", "public")
|
||||
|
||||
tagsCol := models.InitColumn("tags", "plans", "public")
|
||||
tagsCol.Type = "text[]"
|
||||
tagsCol.NotNull = true
|
||||
tagsCol.Default = "''{}''"
|
||||
table.Columns["tags"] = tagsCol
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
err = writer.WriteMigration(model, current)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "tags text[] DEFAULT '{}' NOT NULL") {
|
||||
t.Fatalf("expected normalized array default in migration, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "'''{}'''") {
|
||||
t.Fatalf("migration still contains triple-quoted array default:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_WithAudit(t *testing.T) {
|
||||
// Current database (empty)
|
||||
current := models.InitDatabase("testdb")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
//go:embed templates/*.tmpl
|
||||
@@ -495,7 +496,11 @@ func BuildCreateTableData(schemaName string, table *models.Table) CreateTableDat
|
||||
NotNull: col.NotNull,
|
||||
}
|
||||
if col.Default != nil {
|
||||
colData.Default = fmt.Sprintf("%v", col.Default)
|
||||
if value, ok := col.Default.(string); ok {
|
||||
colData.Default = writers.QuoteDefaultValue(value, col.Type)
|
||||
} else {
|
||||
colData.Default = fmt.Sprintf("%v", col.Default)
|
||||
}
|
||||
}
|
||||
columns = append(columns, colData)
|
||||
}
|
||||
|
||||
@@ -523,15 +523,7 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
if col.Default != nil {
|
||||
switch v := col.Default.(type) {
|
||||
case string:
|
||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
||||
cleanDefault := stripBackticks(v)
|
||||
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
||||
} else if cleanDefault == "true" || cleanDefault == "false" {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", writers.QuoteDefaultValue(stripBackticks(v), col.Type)))
|
||||
case bool:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user