test(drawdb): add test for converting column types with modifiers

* Implement tests to ensure explicit type modifiers are preserved during conversion.
* Validate behavior for varchar, numeric, and custom vector types.
This commit is contained in:
2026-04-26 12:35:54 +02:00
parent 535a91d4be
commit 988798998d
24 changed files with 1052 additions and 264 deletions

View File

@@ -12,6 +12,7 @@ import (
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
@@ -700,16 +701,21 @@ func (r *Reader) extractBunTag(tag string) string {
// parseTypeWithLength parses a type string and extracts length if present
// e.g., "varchar(255)" returns ("varchar", 255)
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
typeStr = strings.TrimSpace(typeStr)
baseType = typeStr
// Check for type with length: varchar(255), char(10), etc.
re := regexp.MustCompile(`^([a-zA-Z\s]+)\((\d+)\)$`)
matches := re.FindStringSubmatch(typeStr)
if len(matches) == 3 {
if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil {
baseType = strings.TrimSpace(matches[1])
return
rawBaseType := strings.TrimSpace(matches[1])
if pgsql.SupportsLength(rawBaseType) {
if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil {
return
}
}
}
baseType = typeStr
return
}

View File

@@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) {
if !emailCol.NotNull {
t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)")
}
if emailCol.Type != "varchar" || emailCol.Length != 255 {
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" {
t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
}
if emailCol.Length != 255 {
t.Errorf("Expected email length 255, got %d", emailCol.Length)
}
// Verify name column - primitive string type should be NOT NULL by default in Bun
@@ -356,6 +359,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
}
}
func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) {
reader := &Reader{}
tests := []struct {
input string
wantType string
wantLength int
}{
{"varchar(255)", "varchar(255)", 255},
{"character varying(120)", "character varying(120)", 120},
{"vector(1536)", "vector(1536)", 1536},
{"numeric(10,2)", "numeric(10,2)", 0},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
gotType, gotLength := reader.parseTypeWithLength(tt.input)
if gotType != tt.wantType {
t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
}
if gotLength != tt.wantLength {
t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength)
}
})
}
}
func TestReader_ReadSchema(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"),
@@ -485,9 +515,9 @@ func TestReader_NullableTypes(t *testing.T) {
// Test all nullability scenarios
tests := []struct {
column string
notNull bool
reason string
column string
notNull bool
reason string
}{
{"id", true, "primary key"},
{"user_id", true, "explicit notnull tag"},

View File

@@ -567,110 +567,182 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
// parseColumn parses a DBML column definition
func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) {
// Format: column_name type [attributes] // comment
parts := strings.Fields(line)
if len(parts) < 2 {
lineNoComment, inlineComment := splitInlineComment(line)
signature, attrs := splitColumnSignatureAndAttrs(lineNoComment)
columnName, columnType, ok := parseColumnSignature(signature)
if !ok {
return nil, nil
}
columnName := stripQuotes(parts[0])
columnType := stripQuotes(parts[1])
column := models.InitColumn(columnName, tableName, schemaName)
column.Type = columnType
var constraint *models.Constraint
// Parse attributes in brackets
if strings.Contains(line, "[") && strings.Contains(line, "]") {
attrStart := strings.Index(line, "[")
attrEnd := strings.Index(line, "]")
if attrStart < attrEnd {
attrs := line[attrStart+1 : attrEnd]
attrList := strings.Split(attrs, ",")
if attrs != "" {
attrList := strings.Split(attrs, ",")
for _, attr := range attrList {
attr = strings.TrimSpace(attr)
for _, attr := range attrList {
attr = strings.TrimSpace(attr)
if strings.Contains(attr, "primary key") || attr == "pk" {
column.IsPrimaryKey = true
column.NotNull = true
} else if strings.Contains(attr, "not null") {
column.NotNull = true
} else if attr == "increment" {
column.AutoIncrement = true
} else if strings.HasPrefix(attr, "default:") {
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
column.Default = strings.Trim(defaultVal, "'\"")
} else if attr == "unique" {
// Create a unique constraint
// Clean table name by removing leading underscores to avoid double underscores
cleanTableName := strings.TrimLeft(tableName, "_")
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
models.UniqueConstraint,
)
uniqueConstraint.Schema = schemaName
uniqueConstraint.Table = tableName
uniqueConstraint.Columns = []string{columnName}
// Store it to be added later
if constraint == nil {
constraint = uniqueConstraint
}
} else if strings.HasPrefix(attr, "note:") {
// Parse column note/comment
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
column.Comment = strings.Trim(note, "'\"")
} else if strings.HasPrefix(attr, "ref:") {
// Parse inline reference
// DBML semantics depend on context:
// - On FK column: ref: < target means "this FK references target"
// - On PK column: ref: < source means "source references this PK" (reverse notation)
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
// Check relationship direction operator
refOp := strings.TrimSpace(refStr)
var isReverse bool
if strings.HasPrefix(refOp, "<") {
// < means "is referenced by" - only makes sense on PK columns
isReverse = column.IsPrimaryKey
}
// > means "references" - always a forward FK, never reverse
constraint = r.parseRef(refStr)
if constraint != nil {
if isReverse {
// Reverse: parsed ref is SOURCE, current column is TARGET
// Constraint should be ON the source table
constraint.Schema = constraint.ReferencedSchema
constraint.Table = constraint.ReferencedTable
constraint.Columns = constraint.ReferencedColumns
constraint.ReferencedSchema = schemaName
constraint.ReferencedTable = tableName
constraint.ReferencedColumns = []string{columnName}
} else {
// Forward: current column is SOURCE, parsed ref is TARGET
// Standard FK: constraint is ON current table
constraint.Schema = schemaName
constraint.Table = tableName
constraint.Columns = []string{columnName}
}
// Generate constraint name based on table and columns
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
if strings.Contains(attr, "primary key") || attr == "pk" {
column.IsPrimaryKey = true
column.NotNull = true
} else if strings.Contains(attr, "not null") {
column.NotNull = true
} else if attr == "increment" {
column.AutoIncrement = true
} else if strings.HasPrefix(attr, "default:") {
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
column.Default = strings.Trim(defaultVal, "'\"")
} else if attr == "unique" {
// Create a unique constraint
// Clean table name by removing leading underscores to avoid double underscores
cleanTableName := strings.TrimLeft(tableName, "_")
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
models.UniqueConstraint,
)
uniqueConstraint.Schema = schemaName
uniqueConstraint.Table = tableName
uniqueConstraint.Columns = []string{columnName}
// Store it to be added later
if constraint == nil {
constraint = uniqueConstraint
}
} else if strings.HasPrefix(attr, "note:") {
// Parse column note/comment
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
column.Comment = strings.Trim(note, "'\"")
} else if strings.HasPrefix(attr, "ref:") {
// Parse inline reference
// DBML semantics depend on context:
// - On FK column: ref: < target means "this FK references target"
// - On PK column: ref: < source means "source references this PK" (reverse notation)
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
// Check relationship direction operator
refOp := strings.TrimSpace(refStr)
var isReverse bool
if strings.HasPrefix(refOp, "<") {
// < means "is referenced by" - only makes sense on PK columns
isReverse = column.IsPrimaryKey
}
// > means "references" - always a forward FK, never reverse
constraint = r.parseRef(refStr)
if constraint != nil {
if isReverse {
// Reverse: parsed ref is SOURCE, current column is TARGET
// Constraint should be ON the source table
constraint.Schema = constraint.ReferencedSchema
constraint.Table = constraint.ReferencedTable
constraint.Columns = constraint.ReferencedColumns
constraint.ReferencedSchema = schemaName
constraint.ReferencedTable = tableName
constraint.ReferencedColumns = []string{columnName}
} else {
// Forward: current column is SOURCE, parsed ref is TARGET
// Standard FK: constraint is ON current table
constraint.Schema = schemaName
constraint.Table = tableName
constraint.Columns = []string{columnName}
}
// Generate constraint name based on table and columns
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
}
}
}
}
// Parse inline comment
if strings.Contains(line, "//") {
commentStart := strings.Index(line, "//")
column.Comment = strings.TrimSpace(line[commentStart+2:])
if inlineComment != "" {
column.Comment = inlineComment
}
return column, constraint
}
func splitInlineComment(line string) (string, string) {
commentStart := strings.Index(line, "//")
if commentStart == -1 {
return line, ""
}
return strings.TrimSpace(line[:commentStart]), strings.TrimSpace(line[commentStart+2:])
}
func splitColumnSignatureAndAttrs(line string) (string, string) {
trimmed := strings.TrimSpace(line)
if trimmed == "" || !strings.HasSuffix(trimmed, "]") {
return trimmed, ""
}
bracketDepth := 0
for i := len(trimmed) - 1; i >= 0; i-- {
switch trimmed[i] {
case ']':
bracketDepth++
case '[':
bracketDepth--
if bracketDepth == 0 {
// DBML attributes are a trailing [ ... ] block preceded by whitespace.
// This avoids confusing array types like text[] with attribute blocks.
if i > 0 && (trimmed[i-1] == ' ' || trimmed[i-1] == '\t') {
return strings.TrimSpace(trimmed[:i]), strings.TrimSpace(trimmed[i+1 : len(trimmed)-1])
}
}
}
}
return trimmed, ""
}
func parseColumnSignature(signature string) (string, string, bool) {
signature = strings.TrimSpace(signature)
if signature == "" {
return "", "", false
}
var splitAt int
if signature[0] == '"' || signature[0] == '\'' {
quote := signature[0]
splitAt = 1
for splitAt < len(signature) {
if signature[splitAt] == quote {
splitAt++
break
}
splitAt++
}
} else {
for splitAt < len(signature) && signature[splitAt] != ' ' && signature[splitAt] != '\t' {
splitAt++
}
}
if splitAt <= 0 || splitAt >= len(signature) {
return "", "", false
}
columnName := stripQuotes(strings.TrimSpace(signature[:splitAt]))
columnType := stripWrappingQuotes(strings.TrimSpace(signature[splitAt:]))
if columnName == "" || columnType == "" {
return "", "", false
}
return columnName, columnType, true
}
func stripWrappingQuotes(s string) string {
s = strings.TrimSpace(s)
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
return s[1 : len(s)-1]
}
return s
}
// parseIndex parses a DBML index definition
func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Format: (columns) [attributes] OR columnname [attributes]

View File

@@ -839,6 +839,67 @@ func TestConstraintNaming(t *testing.T) {
}
}
func TestParseColumn_PostgresTypes(t *testing.T) {
reader := &Reader{}
tests := []struct {
name string
line string
wantName string
wantType string
wantNotNull bool
wantComment string
}{
{
name: "array type with attrs",
line: "tags text[] [not null]",
wantName: "tags",
wantType: "text[]",
wantNotNull: true,
},
{
name: "vector with dimension",
line: "embedding vector(1536)",
wantName: "embedding",
wantType: "vector(1536)",
},
{
name: "multi word timestamp type",
line: "published_at timestamp with time zone",
wantName: "published_at",
wantType: "timestamp with time zone",
},
{
name: "array type with inline comment",
line: "labels varchar(20)[] // column labels",
wantName: "labels",
wantType: "varchar(20)[]",
wantComment: "column labels",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
col, _ := reader.parseColumn(tt.line, "events", "public")
if col == nil {
t.Fatalf("parseColumn() returned nil column")
}
if col.Name != tt.wantName {
t.Errorf("column name = %q, want %q", col.Name, tt.wantName)
}
if col.Type != tt.wantType {
t.Errorf("column type = %q, want %q", col.Type, tt.wantType)
}
if col.NotNull != tt.wantNotNull {
t.Errorf("column not null = %v, want %v", col.NotNull, tt.wantNotNull)
}
if col.Comment != tt.wantComment {
t.Errorf("column comment = %q, want %q", col.Comment, tt.wantComment)
}
})
}
}
func getKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m))
for k := range m {

View File

@@ -8,6 +8,7 @@ import (
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
)
@@ -231,30 +232,35 @@ func (r *Reader) convertToColumn(field *drawdb.DrawDBField, tableName, schemaNam
// Parse type and dimensions
typeStr := field.Type
typeStr = strings.TrimSpace(typeStr)
column.Type = typeStr
// Try to extract length/precision from type string like "varchar(255)" or "decimal(10,2)"
if strings.Contains(typeStr, "(") {
parts := strings.Split(typeStr, "(")
column.Type = parts[0]
baseType := strings.TrimSpace(parts[0])
if len(parts) > 1 {
dimensions := strings.TrimSuffix(parts[1], ")")
if strings.Contains(dimensions, ",") {
// Precision and scale (e.g., decimal(10,2))
dims := strings.Split(dimensions, ",")
if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil {
column.Precision = precision
}
if len(dims) > 1 {
if scale, err := strconv.Atoi(strings.TrimSpace(dims[1])); err == nil {
column.Scale = scale
// Precision and scale (e.g., decimal(10,2), numeric(10,2))
if pgsql.SupportsPrecision(baseType) {
dims := strings.Split(dimensions, ",")
if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil {
column.Precision = precision
}
if len(dims) > 1 {
if scale, err := strconv.Atoi(strings.TrimSpace(dims[1])); err == nil {
column.Scale = scale
}
}
}
} else {
// Just length (e.g., varchar(255))
if length, err := strconv.Atoi(dimensions); err == nil {
column.Length = length
if pgsql.SupportsLength(baseType) {
if length, err := strconv.Atoi(dimensions); err == nil {
column.Length = length
}
}
}
}

View File

@@ -6,6 +6,7 @@ import (
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
)
func TestReader_ReadDatabase_Simple(t *testing.T) {
@@ -288,6 +289,61 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
}
}
func TestConvertToColumn_PreservesExplicitTypeModifiers(t *testing.T) {
reader := &Reader{}
tests := []struct {
name string
fieldType string
wantType string
wantLength int
wantPrecision int
wantScale int
}{
{
name: "varchar with length",
fieldType: "varchar(255)",
wantType: "varchar(255)",
wantLength: 255,
},
{
name: "numeric precision/scale",
fieldType: "numeric(10,2)",
wantType: "numeric(10,2)",
wantPrecision: 10,
wantScale: 2,
},
{
name: "custom vector modifier",
fieldType: "vector(1536)",
wantType: "vector(1536)",
wantLength: 1536,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
field := &drawdb.DrawDBField{
Name: tt.name,
Type: tt.fieldType,
}
col := reader.convertToColumn(field, "events", "public")
if col.Type != tt.wantType {
t.Fatalf("column type = %q, want %q", col.Type, tt.wantType)
}
if col.Length != tt.wantLength {
t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength)
}
if col.Precision != tt.wantPrecision {
t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision)
}
if col.Scale != tt.wantScale {
t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale)
}
})
}
}
func TestReader_ReadSchema(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "drawdb", "simple.json"),

View File

@@ -12,6 +12,7 @@ import (
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
@@ -784,11 +785,14 @@ func (r *Reader) extractGormTag(tag string) string {
// parseTypeWithLength parses a type string and extracts length if present
// e.g., "varchar(255)" returns ("varchar", 255)
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
typeStr = strings.TrimSpace(typeStr)
baseType = typeStr
// Check for type with length: varchar(255), char(10), etc.
// Also handle precision/scale: numeric(10,2)
if strings.Contains(typeStr, "(") {
idx := strings.Index(typeStr, "(")
baseType = strings.TrimSpace(typeStr[:idx])
rawBaseType := strings.TrimSpace(typeStr[:idx])
// Extract numbers from parentheses
parens := typeStr[idx+1:]
@@ -796,14 +800,15 @@ func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length in
parens = parens[:endIdx]
}
// For now, just handle single number (length)
if !strings.Contains(parens, ",") {
// Only treat as "length" for text-ish SQL types.
// This avoids converting custom modifiers like vector(1536) into Length.
if pgsql.SupportsLength(rawBaseType) && !strings.Contains(parens, ",") {
if _, err := fmt.Sscanf(parens, "%d", &length); err == nil {
return
}
}
}
baseType = typeStr
return
}

View File

@@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) {
if !emailCol.NotNull {
t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)")
}
if emailCol.Type != "varchar" || emailCol.Length != 255 {
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" {
t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
}
if emailCol.Length != 255 {
t.Errorf("Expected email length 255, got %d", emailCol.Length)
}
// Verify name column - primitive string type should be NOT NULL by default
@@ -363,6 +366,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
}
}
func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) {
reader := &Reader{}
tests := []struct {
input string
wantType string
wantLength int
}{
{"varchar(255)", "varchar(255)", 255},
{"character varying(120)", "character varying(120)", 120},
{"vector(1536)", "vector(1536)", 1536},
{"numeric(10,2)", "numeric(10,2)", 0},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
gotType, gotLength := reader.parseTypeWithLength(tt.input)
if gotType != tt.wantType {
t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
}
if gotLength != tt.wantLength {
t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength)
}
})
}
}
func TestReader_ReadSchema(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"),

View File

@@ -206,8 +206,19 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
c.numeric_precision,
c.numeric_scale,
c.udt_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_data_type,
col_description((c.table_schema||'.'||c.table_name)::regclass, c.ordinal_position) as description
FROM information_schema.columns c
JOIN pg_catalog.pg_namespace n
ON n.nspname = c.table_schema
JOIN pg_catalog.pg_class cls
ON cls.relname = c.table_name
AND cls.relnamespace = n.oid
JOIN pg_catalog.pg_attribute a
ON a.attrelid = cls.oid
AND a.attname = c.column_name
AND a.attnum > 0
AND NOT a.attisdropped
WHERE c.table_schema = $1
ORDER BY c.table_schema, c.table_name, c.ordinal_position
`
@@ -221,12 +232,12 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
columnsMap := make(map[string]map[string]*models.Column)
for rows.Next() {
var schema, tableName, columnName, isNullable, dataType, udtName string
var schema, tableName, columnName, isNullable, dataType, udtName, formattedDataType string
var ordinalPosition int
var columnDefault, description *string
var charMaxLength, numPrecision, numScale *int
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &description); err != nil {
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &formattedDataType, &description); err != nil {
return nil, err
}
@@ -246,7 +257,7 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
}
// Map data type, preserving serial types when detected
column.Type = r.mapDataType(dataType, udtName, hasNextval)
column.Type = r.mapDataType(dataType, udtName, formattedDataType, hasNextval)
column.NotNull = (isNullable == "NO")
column.Sequence = uint(ordinalPosition)

View File

@@ -259,12 +259,14 @@ func (r *Reader) close() {
}
}
// mapDataType maps PostgreSQL data types to canonical types
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
// mapDataType maps PostgreSQL data types while preserving exact type text when available.
func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
// If the column has a nextval default, it's likely a serial type
// Map to the appropriate serial type instead of the base integer type
if hasNextval {
switch strings.ToLower(pgType) {
switch normalizedPGType {
case "integer", "int", "int4":
return "serial"
case "bigint", "int8":
@@ -274,6 +276,17 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
}
}
// Prefer the database-provided formatted type; this preserves arrays/custom
// types/modifiers like text[], vector(1536), numeric(10,2), etc.
if strings.TrimSpace(formattedType) != "" {
return formattedType
}
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
return udtName[1:] + "[]"
}
// Map common PostgreSQL types
typeMap := map[string]string{
"integer": "integer",
@@ -320,7 +333,7 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
}
// Try mapped type first
if mapped, exists := typeMap[pgType]; exists {
if mapped, exists := typeMap[normalizedPGType]; exists {
return mapped
}
@@ -329,8 +342,11 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
return pgsql.GetSQLType(pgType)
}
// Return UDT name for custom types
// Return UDT name for custom types (including array fallback when needed)
if udtName != "" {
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
return udtName[1:] + "[]"
}
return udtName
}

View File

@@ -173,35 +173,39 @@ func TestMapDataType(t *testing.T) {
reader := &Reader{}
tests := []struct {
pgType string
udtName string
expected string
pgType string
udtName string
formattedType string
expected string
}{
{"integer", "int4", "integer"},
{"bigint", "int8", "bigint"},
{"smallint", "int2", "smallint"},
{"character varying", "varchar", "varchar"},
{"text", "text", "text"},
{"boolean", "bool", "boolean"},
{"timestamp without time zone", "timestamp", "timestamp"},
{"timestamp with time zone", "timestamptz", "timestamptz"},
{"json", "json", "json"},
{"jsonb", "jsonb", "jsonb"},
{"uuid", "uuid", "uuid"},
{"numeric", "numeric", "numeric"},
{"real", "float4", "real"},
{"double precision", "float8", "double precision"},
{"date", "date", "date"},
{"time without time zone", "time", "time"},
{"bytea", "bytea", "bytea"},
{"unknown_type", "custom", "custom"}, // Should return UDT name
{"integer", "int4", "", "integer"},
{"bigint", "int8", "", "bigint"},
{"smallint", "int2", "", "smallint"},
{"character varying", "varchar", "", "varchar"},
{"text", "text", "", "text"},
{"boolean", "bool", "", "boolean"},
{"timestamp without time zone", "timestamp", "", "timestamp"},
{"timestamp with time zone", "timestamptz", "", "timestamptz"},
{"json", "json", "", "json"},
{"jsonb", "jsonb", "", "jsonb"},
{"uuid", "uuid", "", "uuid"},
{"numeric", "numeric", "", "numeric"},
{"real", "float4", "", "real"},
{"double precision", "float8", "", "double precision"},
{"date", "date", "", "date"},
{"time without time zone", "time", "", "time"},
{"bytea", "bytea", "", "bytea"},
{"unknown_type", "custom", "", "custom"}, // Should return UDT name
{"ARRAY", "_text", "", "text[]"},
{"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"},
{"character varying", "varchar", "character varying(255)", "character varying(255)"},
}
for _, tt := range tests {
t.Run(tt.pgType, func(t *testing.T) {
result := reader.mapDataType(tt.pgType, tt.udtName, false)
result := reader.mapDataType(tt.pgType, tt.udtName, tt.formattedType, false)
if result != tt.expected {
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
t.Errorf("mapDataType(%s, %s, %s) = %s, expected %s", tt.pgType, tt.udtName, tt.formattedType, result, tt.expected)
}
})
}
@@ -218,9 +222,9 @@ func TestMapDataType(t *testing.T) {
for _, tt := range serialTests {
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
result := reader.mapDataType(tt.pgType, "", true)
result := reader.mapDataType(tt.pgType, "", "", true)
if result != tt.expected {
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
t.Errorf("mapDataType(%s, '', '', true) = %s, expected %s", tt.pgType, result, tt.expected)
}
})
}
@@ -230,63 +234,63 @@ func TestParseIndexDefinition(t *testing.T) {
reader := &Reader{}
tests := []struct {
name string
indexName string
tableName string
schema string
indexDef string
wantType string
wantUnique bool
name string
indexName string
tableName string
schema string
indexDef string
wantType string
wantUnique bool
wantColumns int
}{
{
name: "simple btree index",
indexName: "idx_users_email",
tableName: "users",
schema: "public",
indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)",
wantType: "btree",
wantUnique: false,
name: "simple btree index",
indexName: "idx_users_email",
tableName: "users",
schema: "public",
indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)",
wantType: "btree",
wantUnique: false,
wantColumns: 1,
},
{
name: "unique index",
indexName: "idx_users_username",
tableName: "users",
schema: "public",
indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING btree (username)",
wantType: "btree",
wantUnique: true,
name: "unique index",
indexName: "idx_users_username",
tableName: "users",
schema: "public",
indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING btree (username)",
wantType: "btree",
wantUnique: true,
wantColumns: 1,
},
{
name: "composite index",
indexName: "idx_users_name",
tableName: "users",
schema: "public",
indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)",
wantType: "btree",
wantUnique: false,
name: "composite index",
indexName: "idx_users_name",
tableName: "users",
schema: "public",
indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)",
wantType: "btree",
wantUnique: false,
wantColumns: 2,
},
{
name: "gin index",
indexName: "idx_posts_tags",
tableName: "posts",
schema: "public",
indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)",
wantType: "gin",
wantUnique: false,
name: "gin index",
indexName: "idx_posts_tags",
tableName: "posts",
schema: "public",
indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)",
wantType: "gin",
wantUnique: false,
wantColumns: 1,
},
{
name: "partial index with where clause",
indexName: "idx_users_active",
tableName: "users",
schema: "public",
indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)",
wantType: "btree",
wantUnique: false,
name: "partial index with where clause",
indexName: "idx_users_active",
tableName: "users",
schema: "public",
indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)",
wantType: "btree",
wantUnique: false,
wantColumns: 1,
},
}

View File

@@ -5,9 +5,11 @@ import (
"fmt"
"os"
"regexp"
"strconv"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
@@ -549,6 +551,41 @@ func (r *Reader) parseColumnOptions(decorator string, column *models.Column, tab
}
}
// Preserve explicit type modifiers from options where present.
// Example: @Column({ type: 'varchar', length: 255 }) -> varchar(255)
if column.Type != "" && !strings.Contains(column.Type, "(") {
lengthRegex := regexp.MustCompile(`length:\s*(\d+)`)
precisionRegex := regexp.MustCompile(`precision:\s*(\d+)`)
scaleRegex := regexp.MustCompile(`scale:\s*(\d+)`)
baseType := strings.ToLower(strings.TrimSpace(column.Type))
if pgsql.SupportsLength(baseType) {
if matches := lengthRegex.FindStringSubmatch(content); len(matches) == 2 {
if n, err := strconv.Atoi(matches[1]); err == nil && n > 0 {
column.Length = n
column.Type = fmt.Sprintf("%s(%d)", column.Type, n)
}
}
}
if pgsql.SupportsPrecision(baseType) {
if matches := precisionRegex.FindStringSubmatch(content); len(matches) == 2 {
if p, err := strconv.Atoi(matches[1]); err == nil && p > 0 {
column.Precision = p
if sm := scaleRegex.FindStringSubmatch(content); len(sm) == 2 {
if s, err := strconv.Atoi(sm[1]); err == nil && s >= 0 {
column.Scale = s
column.Type = fmt.Sprintf("%s(%d,%d)", column.Type, p, s)
}
} else {
column.Type = fmt.Sprintf("%s(%d)", column.Type, p)
}
}
}
}
}
if strings.Contains(content, "nullable: true") || strings.Contains(content, "nullable:true") {
column.NotNull = false
}

View File

@@ -0,0 +1,60 @@
package typeorm
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestParseColumnOptions_PreservesTypeModifiers(t *testing.T) {
reader := &Reader{}
table := models.InitTable("users", "public")
tests := []struct {
name string
decorator string
wantType string
wantLength int
wantPrecision int
wantScale int
}{
{
name: "varchar with length",
decorator: `@Column({ type: 'varchar', length: 255 })`,
wantType: "varchar(255)",
wantLength: 255,
},
{
name: "numeric with precision and scale",
decorator: `@Column({ type: 'numeric', precision: 10, scale: 2 })`,
wantType: "numeric(10,2)",
wantPrecision: 10,
wantScale: 2,
},
{
name: "custom type with explicit modifier is preserved",
decorator: `@Column({ type: 'vector(1536)' })`,
wantType: "vector(1536)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
col := models.InitColumn("sample", table.Name, table.Schema)
reader.parseColumnOptions(tt.decorator, col, table)
if col.Type != tt.wantType {
t.Fatalf("column type = %q, want %q", col.Type, tt.wantType)
}
if col.Length != tt.wantLength {
t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength)
}
if col.Precision != tt.wantPrecision {
t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision)
}
if col.Scale != tt.wantScale {
t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale)
}
})
}
}