fix(pgsql): handle default values for array types in migrations
* update default value quoting logic for PostgreSQL * add tests for array default value handling
This commit is contained in:
@@ -329,8 +329,12 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
// Column doesn't exist, add it
|
// Column doesn't exist, add it
|
||||||
defaultVal := ""
|
defaultVal := ""
|
||||||
if modelCol.Default != nil {
|
if modelCol.Default != nil {
|
||||||
|
if value, ok := modelCol.Default.(string); ok {
|
||||||
|
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
|
||||||
|
} else {
|
||||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sql, err := w.executor.ExecuteAddColumn(AddColumnData{
|
sql, err := w.executor.ExecuteAddColumn(AddColumnData{
|
||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
@@ -382,8 +386,12 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
setDefault := modelCol.Default != nil
|
setDefault := modelCol.Default != nil
|
||||||
defaultVal := ""
|
defaultVal := ""
|
||||||
if setDefault {
|
if setDefault {
|
||||||
|
if value, ok := modelCol.Default.(string); ok {
|
||||||
|
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
|
||||||
|
} else {
|
||||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
|
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
|
||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
|
|||||||
@@ -57,6 +57,46 @@ func TestWriteMigration_NewTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_ArrayDefault(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("plans", "public")
|
||||||
|
|
||||||
|
tagsCol := models.InitColumn("tags", "plans", "public")
|
||||||
|
tagsCol.Type = "text[]"
|
||||||
|
tagsCol.NotNull = true
|
||||||
|
tagsCol.Default = "''{}''"
|
||||||
|
table.Columns["tags"] = tagsCol
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err = writer.WriteMigration(model, current)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "tags text[] DEFAULT '{}' NOT NULL") {
|
||||||
|
t.Fatalf("expected normalized array default in migration, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "'''{}'''") {
|
||||||
|
t.Fatalf("migration still contains triple-quoted array default:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteMigration_WithAudit(t *testing.T) {
|
func TestWriteMigration_WithAudit(t *testing.T) {
|
||||||
// Current database (empty)
|
// Current database (empty)
|
||||||
current := models.InitDatabase("testdb")
|
current := models.InitDatabase("testdb")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed templates/*.tmpl
|
//go:embed templates/*.tmpl
|
||||||
@@ -495,8 +496,12 @@ func BuildCreateTableData(schemaName string, table *models.Table) CreateTableDat
|
|||||||
NotNull: col.NotNull,
|
NotNull: col.NotNull,
|
||||||
}
|
}
|
||||||
if col.Default != nil {
|
if col.Default != nil {
|
||||||
|
if value, ok := col.Default.(string); ok {
|
||||||
|
colData.Default = writers.QuoteDefaultValue(value, col.Type)
|
||||||
|
} else {
|
||||||
colData.Default = fmt.Sprintf("%v", col.Default)
|
colData.Default = fmt.Sprintf("%v", col.Default)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columns = append(columns, colData)
|
columns = append(columns, colData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -523,15 +523,7 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|||||||
if col.Default != nil {
|
if col.Default != nil {
|
||||||
switch v := col.Default.(type) {
|
switch v := col.Default.(type) {
|
||||||
case string:
|
case string:
|
||||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", writers.QuoteDefaultValue(stripBackticks(v), col.Type)))
|
||||||
cleanDefault := stripBackticks(v)
|
|
||||||
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
|
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
|
||||||
} else if cleanDefault == "true" || cleanDefault == "false" {
|
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
|
||||||
} else {
|
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
|
|
||||||
}
|
|
||||||
case bool:
|
case bool:
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -110,8 +110,12 @@ func SanitizeFilename(name string) string {
|
|||||||
// Examples (bigint): "0" → "0"
|
// Examples (bigint): "0" → "0"
|
||||||
// Examples (timestamp): "now()" → "now()" (function call – never quoted)
|
// Examples (timestamp): "now()" → "now()" (function call – never quoted)
|
||||||
func QuoteDefaultValue(value, sqlType string) string {
|
func QuoteDefaultValue(value, sqlType string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
// Function calls are never quoted regardless of column type.
|
// Function calls are never quoted regardless of column type.
|
||||||
if strings.Contains(value, "(") || strings.Contains(value, ")") {
|
if strings.Contains(value, "(") || strings.Contains(value, ")") ||
|
||||||
|
strings.Contains(value, "::") ||
|
||||||
|
strings.HasPrefix(strings.ToUpper(value), "ARRAY[") {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +125,16 @@ func QuoteDefaultValue(value, sqlType string) string {
|
|||||||
baseType = baseType[:idx]
|
baseType = baseType[:idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isArraySQLType(baseType) {
|
||||||
|
if arrayLiteral, ok := normalizeArrayDefaultLiteral(value); ok {
|
||||||
|
return quoteSQLLiteral(arrayLiteral)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isQuotedSQLLiteral(value) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// Types whose default values must NOT be quoted.
|
// Types whose default values must NOT be quoted.
|
||||||
unquotedTypes := map[string]bool{
|
unquotedTypes := map[string]bool{
|
||||||
// Integer types
|
// Integer types
|
||||||
@@ -154,7 +168,32 @@ func QuoteDefaultValue(value, sqlType string) string {
|
|||||||
|
|
||||||
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
|
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
|
||||||
// is treated as a quoted literal.
|
// is treated as a quoted literal.
|
||||||
return "'" + value + "'"
|
return quoteSQLLiteral(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isArraySQLType(sqlType string) bool {
|
||||||
|
return strings.HasSuffix(sqlType, "[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeArrayDefaultLiteral(value string) (string, bool) {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(value, "''{") && strings.HasSuffix(value, "}''"):
|
||||||
|
return value[2 : len(value)-2], true
|
||||||
|
case strings.HasPrefix(value, "'{") && strings.HasSuffix(value, "}'"):
|
||||||
|
return value[1 : len(value)-1], true
|
||||||
|
case strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}"):
|
||||||
|
return value, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isQuotedSQLLiteral(value string) bool {
|
||||||
|
return len(value) >= 2 && strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
func quoteSQLLiteral(value string) string {
|
||||||
|
return "'" + strings.ReplaceAll(value, "'", "''") + "'"
|
||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
||||||
|
|||||||
54
pkg/writers/writer_test.go
Normal file
54
pkg/writers/writer_test.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package writers
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestQuoteDefaultValue(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
sqlType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "text default is quoted",
|
||||||
|
value: "active",
|
||||||
|
sqlType: "text",
|
||||||
|
want: "'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array default from bare literal is quoted once",
|
||||||
|
value: "{}",
|
||||||
|
sqlType: "text[]",
|
||||||
|
want: "'{}'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array default from quoted literal is preserved",
|
||||||
|
value: "'{}'",
|
||||||
|
sqlType: "text[]",
|
||||||
|
want: "'{}'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array default from double quoted literal is normalized",
|
||||||
|
value: "''{}''",
|
||||||
|
sqlType: "text[]",
|
||||||
|
want: "'{}'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function default is left alone",
|
||||||
|
value: "now()",
|
||||||
|
sqlType: "timestamptz",
|
||||||
|
want: "now()",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := QuoteDefaultValue(tt.value, tt.sqlType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("QuoteDefaultValue(%q, %q) = %q, want %q", tt.value, tt.sqlType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user