test(drawdb): add test for converting column types with modifiers
* Implement tests to ensure explicit type modifiers are preserved during conversion. * Validate behavior for varchar, numeric, and custom vector types.
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
@@ -39,14 +40,7 @@ func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
||||
|
||||
// extractBaseType extracts the base type from a SQL type string
|
||||
func (tm *TypeMapper) extractBaseType(sqlType string) string {
|
||||
sqlType = strings.ToLower(strings.TrimSpace(sqlType))
|
||||
|
||||
// Remove everything after '('
|
||||
if idx := strings.Index(sqlType, "("); idx > 0 {
|
||||
sqlType = sqlType[:idx]
|
||||
}
|
||||
|
||||
return sqlType
|
||||
return pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
}
|
||||
|
||||
// isSimpleType checks if a type should use base Go type when NOT NULL
|
||||
@@ -184,9 +178,10 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
if column.Type != "" {
|
||||
// Sanitize type to remove backticks
|
||||
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||
if column.Length > 0 {
|
||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
|
||||
if !hasExplicitTypeModifier && column.Length > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||
} else if column.Precision > 0 {
|
||||
} else if !hasExplicitTypeModifier && column.Precision > 0 {
|
||||
if column.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
|
||||
} else {
|
||||
|
||||
@@ -698,3 +698,23 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_BuildBunTag_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
mapper := NewTypeMapper()
|
||||
|
||||
col := &models.Column{
|
||||
Name: "embedding",
|
||||
Type: "vector(1536)",
|
||||
Length: 1536,
|
||||
Precision: 0,
|
||||
Scale: 0,
|
||||
}
|
||||
|
||||
tag := mapper.BuildBunTag(col, nil)
|
||||
if !strings.Contains(tag, "type:vector(1536),") {
|
||||
t.Fatalf("expected explicit modifier to be preserved, got %q", tag)
|
||||
}
|
||||
if strings.Contains(tag, ")(") {
|
||||
t.Fatalf("type modifier appears duplicated in %q", tag)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
)
|
||||
|
||||
// TypeMapper handles SQL to Drizzle type conversions
|
||||
@@ -18,7 +19,7 @@ func NewTypeMapper() *TypeMapper {
|
||||
// SQLTypeToDrizzle converts SQL types to Drizzle column type functions
|
||||
// Returns the Drizzle column constructor (e.g., "integer", "varchar", "text")
|
||||
func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string {
|
||||
sqlTypeLower := strings.ToLower(sqlType)
|
||||
sqlTypeLower := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
|
||||
// PostgreSQL type mapping to Drizzle
|
||||
typeMap := map[string]string{
|
||||
@@ -87,13 +88,6 @@ func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string {
|
||||
return drizzleType
|
||||
}
|
||||
|
||||
// Check for partial matches (e.g., "varchar(255)" -> "varchar")
|
||||
for sqlPattern, drizzleType := range typeMap {
|
||||
if strings.HasPrefix(sqlTypeLower, sqlPattern) {
|
||||
return drizzleType
|
||||
}
|
||||
}
|
||||
|
||||
// Default to text for unknown types
|
||||
return "text"
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
@@ -39,14 +40,7 @@ func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
||||
// extractBaseType extracts the base type from a SQL type string
|
||||
// Examples: varchar(100) → varchar, numeric(10,2) → numeric
|
||||
func (tm *TypeMapper) extractBaseType(sqlType string) string {
|
||||
sqlType = strings.ToLower(strings.TrimSpace(sqlType))
|
||||
|
||||
// Remove everything after '('
|
||||
if idx := strings.Index(sqlType, "("); idx > 0 {
|
||||
sqlType = sqlType[:idx]
|
||||
}
|
||||
|
||||
return sqlType
|
||||
return pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
}
|
||||
|
||||
// baseGoType returns the base Go type for a SQL type (not null)
|
||||
@@ -209,9 +203,10 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
||||
// Include length, precision, scale if present
|
||||
// Sanitize type to remove backticks
|
||||
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||
if column.Length > 0 {
|
||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
|
||||
if !hasExplicitTypeModifier && column.Length > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||
} else if column.Precision > 0 {
|
||||
} else if !hasExplicitTypeModifier && column.Precision > 0 {
|
||||
if column.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
|
||||
} else {
|
||||
|
||||
@@ -14,12 +14,12 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Create a simple table
|
||||
table := models.InitTable("users", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
Sequence: 1,
|
||||
Sequence: 1,
|
||||
}
|
||||
table.Columns["email"] = &models.Column{
|
||||
Name: "email",
|
||||
@@ -444,10 +444,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
|
||||
// Verify all has-many relationships have unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Belongs to via rid_owner
|
||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Belongs to via rid_owner
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
@@ -669,3 +669,23 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
mapper := NewTypeMapper()
|
||||
|
||||
col := &models.Column{
|
||||
Name: "embedding",
|
||||
Type: "vector(1536)",
|
||||
Length: 1536,
|
||||
Precision: 0,
|
||||
Scale: 0,
|
||||
}
|
||||
|
||||
tag := mapper.BuildGormTag(col, nil)
|
||||
if !strings.Contains(tag, "type:vector(1536)") {
|
||||
t.Fatalf("expected explicit modifier to be preserved, got %q", tag)
|
||||
}
|
||||
if strings.Contains(tag, ")(") {
|
||||
t.Fatalf("type modifier appears duplicated in %q", tag)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
)
|
||||
|
||||
func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table *models.Table, schema *models.Schema) string {
|
||||
@@ -33,12 +34,11 @@ func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table *
|
||||
}
|
||||
|
||||
// Standard type mappings
|
||||
baseType := strings.Split(sqlType, "(")[0] // Remove length/precision
|
||||
baseType = strings.TrimSpace(baseType)
|
||||
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
|
||||
// Handle array types
|
||||
if strings.HasSuffix(baseType, "[]") {
|
||||
elemType := strings.TrimSuffix(baseType, "[]")
|
||||
if pgsql.IsArrayType(sqlType) {
|
||||
elemType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(pgsql.ElementType(sqlType)))
|
||||
gqlType := w.mapBaseTypeToGraphQL(elemType)
|
||||
return "[" + gqlType + "]"
|
||||
}
|
||||
@@ -108,8 +108,7 @@ func (w *Writer) sqlTypeToCustomScalar(sqlType string) string {
|
||||
"date": "Date",
|
||||
}
|
||||
|
||||
baseType := strings.Split(sqlType, "(")[0]
|
||||
baseType = strings.TrimSpace(baseType)
|
||||
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
|
||||
if scalar, ok := scalarMap[baseType]; ok {
|
||||
return scalar
|
||||
@@ -132,8 +131,7 @@ func (w *Writer) isIntegerType(sqlType string) bool {
|
||||
"smallserial": true,
|
||||
}
|
||||
|
||||
baseType := strings.Split(sqlType, "(")[0]
|
||||
baseType = strings.TrimSpace(baseType)
|
||||
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||
|
||||
return intTypes[baseType]
|
||||
}
|
||||
|
||||
@@ -493,18 +493,19 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
// Type with length/precision - convert to valid PostgreSQL type
|
||||
baseType := pgsql.ConvertSQLType(col.Type)
|
||||
typeStr := baseType
|
||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
||||
|
||||
// Only add size specifiers for types that support them
|
||||
if col.Length > 0 && col.Precision == 0 {
|
||||
if supportsLength(baseType) {
|
||||
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
||||
if pgsql.SupportsLength(baseType) {
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||
} else if isTextTypeWithoutLength(baseType) {
|
||||
// Convert text with length to varchar
|
||||
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||
}
|
||||
// For types that don't support length (integer, bigint, etc.), ignore the length
|
||||
} else if col.Precision > 0 {
|
||||
if supportsPrecision(baseType) {
|
||||
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
||||
if pgsql.SupportsPrecision(baseType) {
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||
} else {
|
||||
@@ -1268,30 +1269,6 @@ func isTextType(colType string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// supportsLength checks if a PostgreSQL type supports length specification
|
||||
func supportsLength(colType string) bool {
|
||||
lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"}
|
||||
lowerType := strings.ToLower(colType)
|
||||
for _, t := range lengthTypes {
|
||||
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification
|
||||
func supportsPrecision(colType string) bool {
|
||||
precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"}
|
||||
lowerType := strings.ToLower(colType)
|
||||
for _, t := range precisionTypes {
|
||||
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
||||
func isTextTypeWithoutLength(colType string) bool {
|
||||
return strings.EqualFold(colType, "text")
|
||||
|
||||
@@ -426,11 +426,11 @@ func TestWriteAllConstraintTypes(t *testing.T) {
|
||||
|
||||
// Verify all constraint types are present
|
||||
expectedConstraints := map[string]string{
|
||||
"Primary Key": "PRIMARY KEY",
|
||||
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||
"Foreign Key": "FOREIGN KEY",
|
||||
"Primary Key": "PRIMARY KEY",
|
||||
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||
"Foreign Key": "FOREIGN KEY",
|
||||
}
|
||||
|
||||
for name, expected := range expectedConstraints {
|
||||
@@ -715,11 +715,11 @@ func TestColumnSizeSpecifiers(t *testing.T) {
|
||||
|
||||
// Verify valid patterns ARE present
|
||||
validPatterns := []string{
|
||||
"integer", // without size
|
||||
"bigint", // without size
|
||||
"smallint", // without size
|
||||
"varchar(100)", // text converted to varchar with length
|
||||
"varchar(50)", // varchar with length
|
||||
"integer", // without size
|
||||
"bigint", // without size
|
||||
"smallint", // without size
|
||||
"varchar(100)", // text converted to varchar with length
|
||||
"varchar(50)", // varchar with length
|
||||
"decimal(19,4)", // decimal with precision and scale
|
||||
}
|
||||
for _, pattern := range validPatterns {
|
||||
@@ -729,6 +729,56 @@ func TestColumnSizeSpecifiers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
colType string
|
||||
length int
|
||||
precision int
|
||||
scale int
|
||||
wantType string
|
||||
}{
|
||||
{
|
||||
name: "character varying already includes length",
|
||||
colType: "character varying(50)",
|
||||
length: 50,
|
||||
wantType: "character varying(50)",
|
||||
},
|
||||
{
|
||||
name: "numeric already includes precision",
|
||||
colType: "numeric(10,2)",
|
||||
precision: 10,
|
||||
scale: 2,
|
||||
wantType: "numeric(10,2)",
|
||||
},
|
||||
{
|
||||
name: "custom vector modifier preserved",
|
||||
colType: "vector(1536)",
|
||||
wantType: "vector(1536)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
col := models.InitColumn("sample", "events", "public")
|
||||
col.Type = tc.colType
|
||||
col.Length = tc.length
|
||||
col.Precision = tc.precision
|
||||
col.Scale = tc.scale
|
||||
|
||||
def := writer.generateColumnDefinition(col)
|
||||
if !strings.Contains(def, " "+tc.wantType+" ") && !strings.HasSuffix(def, " "+tc.wantType) {
|
||||
t.Fatalf("generated definition %q does not contain expected type %q", def, tc.wantType)
|
||||
}
|
||||
if strings.Contains(def, ")(") {
|
||||
t.Fatalf("generated definition %q appears to duplicate modifiers", def)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAddColumnStatements(t *testing.T) {
|
||||
// Create a test database with tables that have new columns
|
||||
db := models.InitDatabase("testdb")
|
||||
|
||||
Reference in New Issue
Block a user