feat: Enhance PostgreSQL type handling and migration scripts
- Introduced equivalent base types and variants for PostgreSQL types to normalize type comparisons. - Added functions for normalizing SQL types and retrieving equivalent type variants. - Updated migration writer to handle type alterations with checks for existing types. - Implemented logic to create necessary extensions (e.g., pg_trgm) based on schema requirements. - Enhanced tests to cover new functionality for type normalization and migration handling. - Improved handling of GIN indexes to use appropriate operator classes based on column types.
This commit is contained in:
@@ -145,6 +145,24 @@ var postgresTypeAliases = map[string]string{
|
||||
"bool": "boolean",
|
||||
}
|
||||
|
||||
var postgresEquivalentBaseTypes = map[string]string{
|
||||
"character varying": "varchar",
|
||||
"character": "char",
|
||||
"timestamp without time zone": "timestamp",
|
||||
"timestamp with time zone": "timestamptz",
|
||||
"time without time zone": "time",
|
||||
"time with time zone": "timetz",
|
||||
}
|
||||
|
||||
var postgresEquivalentBaseTypeVariants = map[string][]string{
|
||||
"varchar": {"varchar", "character varying"},
|
||||
"char": {"char", "character"},
|
||||
"timestamp": {"timestamp", "timestamp without time zone"},
|
||||
"timestamptz": {"timestamptz", "timestamp with time zone"},
|
||||
"time": {"time", "time without time zone"},
|
||||
"timetz": {"timetz", "time with time zone"},
|
||||
}
|
||||
|
||||
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
|
||||
func GetPostgresBaseTypes() []string {
|
||||
result := make([]string, 0, len(postgresBaseTypes))
|
||||
@@ -212,6 +230,86 @@ func CanonicalizeBaseType(baseType string) string {
|
||||
return base
|
||||
}
|
||||
|
||||
// EquivalentBaseType resolves broader SQL-equivalent spellings to a common comparable form.
|
||||
func EquivalentBaseType(baseType string) string {
|
||||
base := CanonicalizeBaseType(baseType)
|
||||
if equivalent, ok := postgresEquivalentBaseTypes[base]; ok {
|
||||
return equivalent
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// NormalizeEquivalentSQLType returns a normalized SQL type string suitable for equality checks.
|
||||
// Equivalent spellings such as "character varying(255)" and "varchar(255)" normalize identically.
|
||||
func NormalizeEquivalentSQLType(sqlType string) string {
|
||||
t := normalizeTypeToken(sqlType)
|
||||
if t == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
arrayDepth := 0
|
||||
for strings.HasSuffix(t, "[]") {
|
||||
arrayDepth++
|
||||
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
|
||||
}
|
||||
|
||||
modifier := ""
|
||||
if idx := strings.Index(t, "("); idx >= 0 {
|
||||
modifier = strings.TrimSpace(t[idx:])
|
||||
t = strings.TrimSpace(t[:idx])
|
||||
}
|
||||
|
||||
base := EquivalentBaseType(t)
|
||||
normalized := base + modifier
|
||||
for i := 0; i < arrayDepth; i++ {
|
||||
normalized += "[]"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// EquivalentSQLTypeVariants returns equivalent PostgreSQL spellings for a SQL type.
|
||||
// Examples:
|
||||
// - varchar(255) -> ["varchar(255)", "character varying(255)"]
|
||||
// - timestamptz -> ["timestamptz", "timestamp with time zone"]
|
||||
func EquivalentSQLTypeVariants(sqlType string) []string {
|
||||
t := normalizeTypeToken(sqlType)
|
||||
if t == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
arrayDepth := 0
|
||||
for strings.HasSuffix(t, "[]") {
|
||||
arrayDepth++
|
||||
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
|
||||
}
|
||||
|
||||
modifier := ""
|
||||
if idx := strings.Index(t, "("); idx >= 0 {
|
||||
modifier = strings.TrimSpace(t[idx:])
|
||||
t = strings.TrimSpace(t[:idx])
|
||||
}
|
||||
|
||||
base := EquivalentBaseType(t)
|
||||
bases := postgresEquivalentBaseTypeVariants[base]
|
||||
if len(bases) == 0 {
|
||||
bases = []string{base}
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(bases))
|
||||
result := make([]string, 0, len(bases))
|
||||
for _, variantBase := range bases {
|
||||
variant := variantBase + modifier
|
||||
for i := 0; i < arrayDepth; i++ {
|
||||
variant += "[]"
|
||||
}
|
||||
if !seen[variant] {
|
||||
seen[variant] = true
|
||||
result = append(result, variant)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
|
||||
func IsKnownPostgresType(sqlType string) bool {
|
||||
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
|
||||
|
||||
@@ -97,3 +97,51 @@ func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeEquivalentSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{input: "character varying(255)", want: "varchar(255)"},
|
||||
{input: "varchar(255)", want: "varchar(255)"},
|
||||
{input: "timestamp with time zone", want: "timestamptz"},
|
||||
{input: "timestamptz", want: "timestamptz"},
|
||||
{input: "time without time zone", want: "time"},
|
||||
{input: "character varying(255)[]", want: "varchar(255)[]"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := NormalizeEquivalentSQLType(tt.input)
|
||||
if got != tt.want {
|
||||
t.Fatalf("NormalizeEquivalentSQLType(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEquivalentSQLTypeVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{input: "character varying(255)", want: []string{"varchar(255)", "character varying(255)"}},
|
||||
{input: "timestamptz", want: []string{"timestamptz", "timestamp with time zone"}},
|
||||
{input: "text[]", want: []string{"text[]"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := EquivalentSQLTypeVariants(tt.input)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("EquivalentSQLTypeVariants(%q) len = %d, want %d (%v)", tt.input, len(got), len(tt.want), got)
|
||||
}
|
||||
for i := range tt.want {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Fatalf("EquivalentSQLTypeVariants(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user