test(drawdb): add test for converting column types with modifiers
* Implement tests to ensure explicit type modifiers are preserved during conversion. * Validate behavior for varchar, numeric, and custom vector types.
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
@@ -700,16 +701,21 @@ func (r *Reader) extractBunTag(tag string) string {
|
||||
// parseTypeWithLength parses a type string and extracts length if present
|
||||
// e.g., "varchar(255)" returns ("varchar", 255)
|
||||
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
|
||||
typeStr = strings.TrimSpace(typeStr)
|
||||
baseType = typeStr
|
||||
|
||||
// Check for type with length: varchar(255), char(10), etc.
|
||||
re := regexp.MustCompile(`^([a-zA-Z\s]+)\((\d+)\)$`)
|
||||
matches := re.FindStringSubmatch(typeStr)
|
||||
if len(matches) == 3 {
|
||||
if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil {
|
||||
baseType = strings.TrimSpace(matches[1])
|
||||
return
|
||||
rawBaseType := strings.TrimSpace(matches[1])
|
||||
if pgsql.SupportsLength(rawBaseType) {
|
||||
if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
baseType = typeStr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) {
|
||||
if !emailCol.NotNull {
|
||||
t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)")
|
||||
}
|
||||
if emailCol.Type != "varchar" || emailCol.Length != 255 {
|
||||
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
|
||||
if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" {
|
||||
t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
|
||||
}
|
||||
if emailCol.Length != 255 {
|
||||
t.Errorf("Expected email length 255, got %d", emailCol.Length)
|
||||
}
|
||||
|
||||
// Verify name column - primitive string type should be NOT NULL by default in Bun
|
||||
@@ -356,6 +359,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
wantType string
|
||||
wantLength int
|
||||
}{
|
||||
{"varchar(255)", "varchar(255)", 255},
|
||||
{"character varying(120)", "character varying(120)", 120},
|
||||
{"vector(1536)", "vector(1536)", 1536},
|
||||
{"numeric(10,2)", "numeric(10,2)", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
gotType, gotLength := reader.parseTypeWithLength(tt.input)
|
||||
if gotType != tt.wantType {
|
||||
t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
|
||||
}
|
||||
if gotLength != tt.wantLength {
|
||||
t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"),
|
||||
@@ -485,9 +515,9 @@ func TestReader_NullableTypes(t *testing.T) {
|
||||
|
||||
// Test all nullability scenarios
|
||||
tests := []struct {
|
||||
column string
|
||||
notNull bool
|
||||
reason string
|
||||
column string
|
||||
notNull bool
|
||||
reason string
|
||||
}{
|
||||
{"id", true, "primary key"},
|
||||
{"user_id", true, "explicit notnull tag"},
|
||||
|
||||
@@ -567,110 +567,182 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
||||
// parseColumn parses a DBML column definition
|
||||
func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) {
|
||||
// Format: column_name type [attributes] // comment
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 2 {
|
||||
lineNoComment, inlineComment := splitInlineComment(line)
|
||||
signature, attrs := splitColumnSignatureAndAttrs(lineNoComment)
|
||||
columnName, columnType, ok := parseColumnSignature(signature)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
columnName := stripQuotes(parts[0])
|
||||
columnType := stripQuotes(parts[1])
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schemaName)
|
||||
column.Type = columnType
|
||||
|
||||
var constraint *models.Constraint
|
||||
|
||||
// Parse attributes in brackets
|
||||
if strings.Contains(line, "[") && strings.Contains(line, "]") {
|
||||
attrStart := strings.Index(line, "[")
|
||||
attrEnd := strings.Index(line, "]")
|
||||
if attrStart < attrEnd {
|
||||
attrs := line[attrStart+1 : attrEnd]
|
||||
attrList := strings.Split(attrs, ",")
|
||||
if attrs != "" {
|
||||
attrList := strings.Split(attrs, ",")
|
||||
|
||||
for _, attr := range attrList {
|
||||
attr = strings.TrimSpace(attr)
|
||||
for _, attr := range attrList {
|
||||
attr = strings.TrimSpace(attr)
|
||||
|
||||
if strings.Contains(attr, "primary key") || attr == "pk" {
|
||||
column.IsPrimaryKey = true
|
||||
column.NotNull = true
|
||||
} else if strings.Contains(attr, "not null") {
|
||||
column.NotNull = true
|
||||
} else if attr == "increment" {
|
||||
column.AutoIncrement = true
|
||||
} else if strings.HasPrefix(attr, "default:") {
|
||||
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
uniqueConstraint.Table = tableName
|
||||
uniqueConstraint.Columns = []string{columnName}
|
||||
// Store it to be added later
|
||||
if constraint == nil {
|
||||
constraint = uniqueConstraint
|
||||
}
|
||||
} else if strings.HasPrefix(attr, "note:") {
|
||||
// Parse column note/comment
|
||||
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
|
||||
column.Comment = strings.Trim(note, "'\"")
|
||||
} else if strings.HasPrefix(attr, "ref:") {
|
||||
// Parse inline reference
|
||||
// DBML semantics depend on context:
|
||||
// - On FK column: ref: < target means "this FK references target"
|
||||
// - On PK column: ref: < source means "source references this PK" (reverse notation)
|
||||
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
|
||||
|
||||
// Check relationship direction operator
|
||||
refOp := strings.TrimSpace(refStr)
|
||||
var isReverse bool
|
||||
if strings.HasPrefix(refOp, "<") {
|
||||
// < means "is referenced by" - only makes sense on PK columns
|
||||
isReverse = column.IsPrimaryKey
|
||||
}
|
||||
// > means "references" - always a forward FK, never reverse
|
||||
|
||||
constraint = r.parseRef(refStr)
|
||||
if constraint != nil {
|
||||
if isReverse {
|
||||
// Reverse: parsed ref is SOURCE, current column is TARGET
|
||||
// Constraint should be ON the source table
|
||||
constraint.Schema = constraint.ReferencedSchema
|
||||
constraint.Table = constraint.ReferencedTable
|
||||
constraint.Columns = constraint.ReferencedColumns
|
||||
constraint.ReferencedSchema = schemaName
|
||||
constraint.ReferencedTable = tableName
|
||||
constraint.ReferencedColumns = []string{columnName}
|
||||
} else {
|
||||
// Forward: current column is SOURCE, parsed ref is TARGET
|
||||
// Standard FK: constraint is ON current table
|
||||
constraint.Schema = schemaName
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
if strings.Contains(attr, "primary key") || attr == "pk" {
|
||||
column.IsPrimaryKey = true
|
||||
column.NotNull = true
|
||||
} else if strings.Contains(attr, "not null") {
|
||||
column.NotNull = true
|
||||
} else if attr == "increment" {
|
||||
column.AutoIncrement = true
|
||||
} else if strings.HasPrefix(attr, "default:") {
|
||||
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
uniqueConstraint.Table = tableName
|
||||
uniqueConstraint.Columns = []string{columnName}
|
||||
// Store it to be added later
|
||||
if constraint == nil {
|
||||
constraint = uniqueConstraint
|
||||
}
|
||||
} else if strings.HasPrefix(attr, "note:") {
|
||||
// Parse column note/comment
|
||||
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
|
||||
column.Comment = strings.Trim(note, "'\"")
|
||||
} else if strings.HasPrefix(attr, "ref:") {
|
||||
// Parse inline reference
|
||||
// DBML semantics depend on context:
|
||||
// - On FK column: ref: < target means "this FK references target"
|
||||
// - On PK column: ref: < source means "source references this PK" (reverse notation)
|
||||
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
|
||||
|
||||
// Check relationship direction operator
|
||||
refOp := strings.TrimSpace(refStr)
|
||||
var isReverse bool
|
||||
if strings.HasPrefix(refOp, "<") {
|
||||
// < means "is referenced by" - only makes sense on PK columns
|
||||
isReverse = column.IsPrimaryKey
|
||||
}
|
||||
// > means "references" - always a forward FK, never reverse
|
||||
|
||||
constraint = r.parseRef(refStr)
|
||||
if constraint != nil {
|
||||
if isReverse {
|
||||
// Reverse: parsed ref is SOURCE, current column is TARGET
|
||||
// Constraint should be ON the source table
|
||||
constraint.Schema = constraint.ReferencedSchema
|
||||
constraint.Table = constraint.ReferencedTable
|
||||
constraint.Columns = constraint.ReferencedColumns
|
||||
constraint.ReferencedSchema = schemaName
|
||||
constraint.ReferencedTable = tableName
|
||||
constraint.ReferencedColumns = []string{columnName}
|
||||
} else {
|
||||
// Forward: current column is SOURCE, parsed ref is TARGET
|
||||
// Standard FK: constraint is ON current table
|
||||
constraint.Schema = schemaName
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse inline comment
|
||||
if strings.Contains(line, "//") {
|
||||
commentStart := strings.Index(line, "//")
|
||||
column.Comment = strings.TrimSpace(line[commentStart+2:])
|
||||
if inlineComment != "" {
|
||||
column.Comment = inlineComment
|
||||
}
|
||||
|
||||
return column, constraint
|
||||
}
|
||||
|
||||
func splitInlineComment(line string) (string, string) {
|
||||
commentStart := strings.Index(line, "//")
|
||||
if commentStart == -1 {
|
||||
return line, ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(line[:commentStart]), strings.TrimSpace(line[commentStart+2:])
|
||||
}
|
||||
|
||||
func splitColumnSignatureAndAttrs(line string) (string, string) {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || !strings.HasSuffix(trimmed, "]") {
|
||||
return trimmed, ""
|
||||
}
|
||||
|
||||
bracketDepth := 0
|
||||
for i := len(trimmed) - 1; i >= 0; i-- {
|
||||
switch trimmed[i] {
|
||||
case ']':
|
||||
bracketDepth++
|
||||
case '[':
|
||||
bracketDepth--
|
||||
if bracketDepth == 0 {
|
||||
// DBML attributes are a trailing [ ... ] block preceded by whitespace.
|
||||
// This avoids confusing array types like text[] with attribute blocks.
|
||||
if i > 0 && (trimmed[i-1] == ' ' || trimmed[i-1] == '\t') {
|
||||
return strings.TrimSpace(trimmed[:i]), strings.TrimSpace(trimmed[i+1 : len(trimmed)-1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return trimmed, ""
|
||||
}
|
||||
|
||||
func parseColumnSignature(signature string) (string, string, bool) {
|
||||
signature = strings.TrimSpace(signature)
|
||||
if signature == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
var splitAt int
|
||||
if signature[0] == '"' || signature[0] == '\'' {
|
||||
quote := signature[0]
|
||||
splitAt = 1
|
||||
for splitAt < len(signature) {
|
||||
if signature[splitAt] == quote {
|
||||
splitAt++
|
||||
break
|
||||
}
|
||||
splitAt++
|
||||
}
|
||||
} else {
|
||||
for splitAt < len(signature) && signature[splitAt] != ' ' && signature[splitAt] != '\t' {
|
||||
splitAt++
|
||||
}
|
||||
}
|
||||
|
||||
if splitAt <= 0 || splitAt >= len(signature) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
columnName := stripQuotes(strings.TrimSpace(signature[:splitAt]))
|
||||
columnType := stripWrappingQuotes(strings.TrimSpace(signature[splitAt:]))
|
||||
if columnName == "" || columnType == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
return columnName, columnType, true
|
||||
}
|
||||
|
||||
func stripWrappingQuotes(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
|
||||
return s[1 : len(s)-1]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// parseIndex parses a DBML index definition
|
||||
func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
||||
// Format: (columns) [attributes] OR columnname [attributes]
|
||||
|
||||
@@ -839,6 +839,67 @@ func TestConstraintNaming(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseColumn_PostgresTypes(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
wantName string
|
||||
wantType string
|
||||
wantNotNull bool
|
||||
wantComment string
|
||||
}{
|
||||
{
|
||||
name: "array type with attrs",
|
||||
line: "tags text[] [not null]",
|
||||
wantName: "tags",
|
||||
wantType: "text[]",
|
||||
wantNotNull: true,
|
||||
},
|
||||
{
|
||||
name: "vector with dimension",
|
||||
line: "embedding vector(1536)",
|
||||
wantName: "embedding",
|
||||
wantType: "vector(1536)",
|
||||
},
|
||||
{
|
||||
name: "multi word timestamp type",
|
||||
line: "published_at timestamp with time zone",
|
||||
wantName: "published_at",
|
||||
wantType: "timestamp with time zone",
|
||||
},
|
||||
{
|
||||
name: "array type with inline comment",
|
||||
line: "labels varchar(20)[] // column labels",
|
||||
wantName: "labels",
|
||||
wantType: "varchar(20)[]",
|
||||
wantComment: "column labels",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
col, _ := reader.parseColumn(tt.line, "events", "public")
|
||||
if col == nil {
|
||||
t.Fatalf("parseColumn() returned nil column")
|
||||
}
|
||||
if col.Name != tt.wantName {
|
||||
t.Errorf("column name = %q, want %q", col.Name, tt.wantName)
|
||||
}
|
||||
if col.Type != tt.wantType {
|
||||
t.Errorf("column type = %q, want %q", col.Type, tt.wantType)
|
||||
}
|
||||
if col.NotNull != tt.wantNotNull {
|
||||
t.Errorf("column not null = %v, want %v", col.NotNull, tt.wantNotNull)
|
||||
}
|
||||
if col.Comment != tt.wantComment {
|
||||
t.Errorf("column comment = %q, want %q", col.Comment, tt.wantComment)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getKeys[V any](m map[string]V) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
|
||||
)
|
||||
@@ -231,30 +232,35 @@ func (r *Reader) convertToColumn(field *drawdb.DrawDBField, tableName, schemaNam
|
||||
|
||||
// Parse type and dimensions
|
||||
typeStr := field.Type
|
||||
typeStr = strings.TrimSpace(typeStr)
|
||||
column.Type = typeStr
|
||||
|
||||
// Try to extract length/precision from type string like "varchar(255)" or "decimal(10,2)"
|
||||
if strings.Contains(typeStr, "(") {
|
||||
parts := strings.Split(typeStr, "(")
|
||||
column.Type = parts[0]
|
||||
baseType := strings.TrimSpace(parts[0])
|
||||
|
||||
if len(parts) > 1 {
|
||||
dimensions := strings.TrimSuffix(parts[1], ")")
|
||||
if strings.Contains(dimensions, ",") {
|
||||
// Precision and scale (e.g., decimal(10,2))
|
||||
dims := strings.Split(dimensions, ",")
|
||||
if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil {
|
||||
column.Precision = precision
|
||||
}
|
||||
if len(dims) > 1 {
|
||||
if scale, err := strconv.Atoi(strings.TrimSpace(dims[1])); err == nil {
|
||||
column.Scale = scale
|
||||
// Precision and scale (e.g., decimal(10,2), numeric(10,2))
|
||||
if pgsql.SupportsPrecision(baseType) {
|
||||
dims := strings.Split(dimensions, ",")
|
||||
if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil {
|
||||
column.Precision = precision
|
||||
}
|
||||
if len(dims) > 1 {
|
||||
if scale, err := strconv.Atoi(strings.TrimSpace(dims[1])); err == nil {
|
||||
column.Scale = scale
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Just length (e.g., varchar(255))
|
||||
if length, err := strconv.Atoi(dimensions); err == nil {
|
||||
column.Length = length
|
||||
if pgsql.SupportsLength(baseType) {
|
||||
if length, err := strconv.Atoi(dimensions); err == nil {
|
||||
column.Length = length
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
|
||||
)
|
||||
|
||||
func TestReader_ReadDatabase_Simple(t *testing.T) {
|
||||
@@ -288,6 +289,61 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToColumn_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldType string
|
||||
wantType string
|
||||
wantLength int
|
||||
wantPrecision int
|
||||
wantScale int
|
||||
}{
|
||||
{
|
||||
name: "varchar with length",
|
||||
fieldType: "varchar(255)",
|
||||
wantType: "varchar(255)",
|
||||
wantLength: 255,
|
||||
},
|
||||
{
|
||||
name: "numeric precision/scale",
|
||||
fieldType: "numeric(10,2)",
|
||||
wantType: "numeric(10,2)",
|
||||
wantPrecision: 10,
|
||||
wantScale: 2,
|
||||
},
|
||||
{
|
||||
name: "custom vector modifier",
|
||||
fieldType: "vector(1536)",
|
||||
wantType: "vector(1536)",
|
||||
wantLength: 1536,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
field := &drawdb.DrawDBField{
|
||||
Name: tt.name,
|
||||
Type: tt.fieldType,
|
||||
}
|
||||
col := reader.convertToColumn(field, "events", "public")
|
||||
if col.Type != tt.wantType {
|
||||
t.Fatalf("column type = %q, want %q", col.Type, tt.wantType)
|
||||
}
|
||||
if col.Length != tt.wantLength {
|
||||
t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength)
|
||||
}
|
||||
if col.Precision != tt.wantPrecision {
|
||||
t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision)
|
||||
}
|
||||
if col.Scale != tt.wantScale {
|
||||
t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "drawdb", "simple.json"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
@@ -784,11 +785,14 @@ func (r *Reader) extractGormTag(tag string) string {
|
||||
// parseTypeWithLength parses a type string and extracts length if present
|
||||
// e.g., "varchar(255)" returns ("varchar", 255)
|
||||
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
|
||||
typeStr = strings.TrimSpace(typeStr)
|
||||
baseType = typeStr
|
||||
|
||||
// Check for type with length: varchar(255), char(10), etc.
|
||||
// Also handle precision/scale: numeric(10,2)
|
||||
if strings.Contains(typeStr, "(") {
|
||||
idx := strings.Index(typeStr, "(")
|
||||
baseType = strings.TrimSpace(typeStr[:idx])
|
||||
rawBaseType := strings.TrimSpace(typeStr[:idx])
|
||||
|
||||
// Extract numbers from parentheses
|
||||
parens := typeStr[idx+1:]
|
||||
@@ -796,14 +800,15 @@ func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length in
|
||||
parens = parens[:endIdx]
|
||||
}
|
||||
|
||||
// For now, just handle single number (length)
|
||||
if !strings.Contains(parens, ",") {
|
||||
// Only treat as "length" for text-ish SQL types.
|
||||
// This avoids converting custom modifiers like vector(1536) into Length.
|
||||
if pgsql.SupportsLength(rawBaseType) && !strings.Contains(parens, ",") {
|
||||
if _, err := fmt.Sscanf(parens, "%d", &length); err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
baseType = typeStr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) {
|
||||
if !emailCol.NotNull {
|
||||
t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)")
|
||||
}
|
||||
if emailCol.Type != "varchar" || emailCol.Length != 255 {
|
||||
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
|
||||
if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" {
|
||||
t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
|
||||
}
|
||||
if emailCol.Length != 255 {
|
||||
t.Errorf("Expected email length 255, got %d", emailCol.Length)
|
||||
}
|
||||
|
||||
// Verify name column - primitive string type should be NOT NULL by default
|
||||
@@ -363,6 +366,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
wantType string
|
||||
wantLength int
|
||||
}{
|
||||
{"varchar(255)", "varchar(255)", 255},
|
||||
{"character varying(120)", "character varying(120)", 120},
|
||||
{"vector(1536)", "vector(1536)", 1536},
|
||||
{"numeric(10,2)", "numeric(10,2)", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
gotType, gotLength := reader.parseTypeWithLength(tt.input)
|
||||
if gotType != tt.wantType {
|
||||
t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
|
||||
}
|
||||
if gotLength != tt.wantLength {
|
||||
t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"),
|
||||
|
||||
@@ -206,8 +206,19 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
c.numeric_precision,
|
||||
c.numeric_scale,
|
||||
c.udt_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_data_type,
|
||||
col_description((c.table_schema||'.'||c.table_name)::regclass, c.ordinal_position) as description
|
||||
FROM information_schema.columns c
|
||||
JOIN pg_catalog.pg_namespace n
|
||||
ON n.nspname = c.table_schema
|
||||
JOIN pg_catalog.pg_class cls
|
||||
ON cls.relname = c.table_name
|
||||
AND cls.relnamespace = n.oid
|
||||
JOIN pg_catalog.pg_attribute a
|
||||
ON a.attrelid = cls.oid
|
||||
AND a.attname = c.column_name
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
WHERE c.table_schema = $1
|
||||
ORDER BY c.table_schema, c.table_name, c.ordinal_position
|
||||
`
|
||||
@@ -221,12 +232,12 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
columnsMap := make(map[string]map[string]*models.Column)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, columnName, isNullable, dataType, udtName string
|
||||
var schema, tableName, columnName, isNullable, dataType, udtName, formattedDataType string
|
||||
var ordinalPosition int
|
||||
var columnDefault, description *string
|
||||
var charMaxLength, numPrecision, numScale *int
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &description); err != nil {
|
||||
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &formattedDataType, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -246,7 +257,7 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
}
|
||||
|
||||
// Map data type, preserving serial types when detected
|
||||
column.Type = r.mapDataType(dataType, udtName, hasNextval)
|
||||
column.Type = r.mapDataType(dataType, udtName, formattedDataType, hasNextval)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
|
||||
@@ -259,12 +259,14 @@ func (r *Reader) close() {
|
||||
}
|
||||
}
|
||||
|
||||
// mapDataType maps PostgreSQL data types to canonical types
|
||||
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||
// mapDataType maps PostgreSQL data types while preserving exact type text when available.
|
||||
func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
|
||||
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
|
||||
|
||||
// If the column has a nextval default, it's likely a serial type
|
||||
// Map to the appropriate serial type instead of the base integer type
|
||||
if hasNextval {
|
||||
switch strings.ToLower(pgType) {
|
||||
switch normalizedPGType {
|
||||
case "integer", "int", "int4":
|
||||
return "serial"
|
||||
case "bigint", "int8":
|
||||
@@ -274,6 +276,17 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer the database-provided formatted type; this preserves arrays/custom
|
||||
// types/modifiers like text[], vector(1536), numeric(10,2), etc.
|
||||
if strings.TrimSpace(formattedType) != "" {
|
||||
return formattedType
|
||||
}
|
||||
|
||||
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
|
||||
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
||||
return udtName[1:] + "[]"
|
||||
}
|
||||
|
||||
// Map common PostgreSQL types
|
||||
typeMap := map[string]string{
|
||||
"integer": "integer",
|
||||
@@ -320,7 +333,7 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||
}
|
||||
|
||||
// Try mapped type first
|
||||
if mapped, exists := typeMap[pgType]; exists {
|
||||
if mapped, exists := typeMap[normalizedPGType]; exists {
|
||||
return mapped
|
||||
}
|
||||
|
||||
@@ -329,8 +342,11 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||
return pgsql.GetSQLType(pgType)
|
||||
}
|
||||
|
||||
// Return UDT name for custom types
|
||||
// Return UDT name for custom types (including array fallback when needed)
|
||||
if udtName != "" {
|
||||
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
||||
return udtName[1:] + "[]"
|
||||
}
|
||||
return udtName
|
||||
}
|
||||
|
||||
|
||||
@@ -173,35 +173,39 @@ func TestMapDataType(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
pgType string
|
||||
udtName string
|
||||
expected string
|
||||
pgType string
|
||||
udtName string
|
||||
formattedType string
|
||||
expected string
|
||||
}{
|
||||
{"integer", "int4", "integer"},
|
||||
{"bigint", "int8", "bigint"},
|
||||
{"smallint", "int2", "smallint"},
|
||||
{"character varying", "varchar", "varchar"},
|
||||
{"text", "text", "text"},
|
||||
{"boolean", "bool", "boolean"},
|
||||
{"timestamp without time zone", "timestamp", "timestamp"},
|
||||
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
||||
{"json", "json", "json"},
|
||||
{"jsonb", "jsonb", "jsonb"},
|
||||
{"uuid", "uuid", "uuid"},
|
||||
{"numeric", "numeric", "numeric"},
|
||||
{"real", "float4", "real"},
|
||||
{"double precision", "float8", "double precision"},
|
||||
{"date", "date", "date"},
|
||||
{"time without time zone", "time", "time"},
|
||||
{"bytea", "bytea", "bytea"},
|
||||
{"unknown_type", "custom", "custom"}, // Should return UDT name
|
||||
{"integer", "int4", "", "integer"},
|
||||
{"bigint", "int8", "", "bigint"},
|
||||
{"smallint", "int2", "", "smallint"},
|
||||
{"character varying", "varchar", "", "varchar"},
|
||||
{"text", "text", "", "text"},
|
||||
{"boolean", "bool", "", "boolean"},
|
||||
{"timestamp without time zone", "timestamp", "", "timestamp"},
|
||||
{"timestamp with time zone", "timestamptz", "", "timestamptz"},
|
||||
{"json", "json", "", "json"},
|
||||
{"jsonb", "jsonb", "", "jsonb"},
|
||||
{"uuid", "uuid", "", "uuid"},
|
||||
{"numeric", "numeric", "", "numeric"},
|
||||
{"real", "float4", "", "real"},
|
||||
{"double precision", "float8", "", "double precision"},
|
||||
{"date", "date", "", "date"},
|
||||
{"time without time zone", "time", "", "time"},
|
||||
{"bytea", "bytea", "", "bytea"},
|
||||
{"unknown_type", "custom", "", "custom"}, // Should return UDT name
|
||||
{"ARRAY", "_text", "", "text[]"},
|
||||
{"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"},
|
||||
{"character varying", "varchar", "character varying(255)", "character varying(255)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pgType, func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName, false)
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName, tt.formattedType, false)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
||||
t.Errorf("mapDataType(%s, %s, %s) = %s, expected %s", tt.pgType, tt.udtName, tt.formattedType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -218,9 +222,9 @@ func TestMapDataType(t *testing.T) {
|
||||
|
||||
for _, tt := range serialTests {
|
||||
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, "", true)
|
||||
result := reader.mapDataType(tt.pgType, "", "", true)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
|
||||
t.Errorf("mapDataType(%s, '', '', true) = %s, expected %s", tt.pgType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -230,63 +234,63 @@ func TestParseIndexDefinition(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
indexName string
|
||||
tableName string
|
||||
schema string
|
||||
indexDef string
|
||||
wantType string
|
||||
wantUnique bool
|
||||
name string
|
||||
indexName string
|
||||
tableName string
|
||||
schema string
|
||||
indexDef string
|
||||
wantType string
|
||||
wantUnique bool
|
||||
wantColumns int
|
||||
}{
|
||||
{
|
||||
name: "simple btree index",
|
||||
indexName: "idx_users_email",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
name: "simple btree index",
|
||||
indexName: "idx_users_email",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
wantColumns: 1,
|
||||
},
|
||||
{
|
||||
name: "unique index",
|
||||
indexName: "idx_users_username",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING btree (username)",
|
||||
wantType: "btree",
|
||||
wantUnique: true,
|
||||
name: "unique index",
|
||||
indexName: "idx_users_username",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING btree (username)",
|
||||
wantType: "btree",
|
||||
wantUnique: true,
|
||||
wantColumns: 1,
|
||||
},
|
||||
{
|
||||
name: "composite index",
|
||||
indexName: "idx_users_name",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
name: "composite index",
|
||||
indexName: "idx_users_name",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
wantColumns: 2,
|
||||
},
|
||||
{
|
||||
name: "gin index",
|
||||
indexName: "idx_posts_tags",
|
||||
tableName: "posts",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)",
|
||||
wantType: "gin",
|
||||
wantUnique: false,
|
||||
name: "gin index",
|
||||
indexName: "idx_posts_tags",
|
||||
tableName: "posts",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)",
|
||||
wantType: "gin",
|
||||
wantUnique: false,
|
||||
wantColumns: 1,
|
||||
},
|
||||
{
|
||||
name: "partial index with where clause",
|
||||
indexName: "idx_users_active",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
name: "partial index with where clause",
|
||||
indexName: "idx_users_active",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
wantColumns: 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
@@ -549,6 +551,41 @@ func (r *Reader) parseColumnOptions(decorator string, column *models.Column, tab
|
||||
}
|
||||
}
|
||||
|
||||
// Preserve explicit type modifiers from options where present.
|
||||
// Example: @Column({ type: 'varchar', length: 255 }) -> varchar(255)
|
||||
if column.Type != "" && !strings.Contains(column.Type, "(") {
|
||||
lengthRegex := regexp.MustCompile(`length:\s*(\d+)`)
|
||||
precisionRegex := regexp.MustCompile(`precision:\s*(\d+)`)
|
||||
scaleRegex := regexp.MustCompile(`scale:\s*(\d+)`)
|
||||
|
||||
baseType := strings.ToLower(strings.TrimSpace(column.Type))
|
||||
|
||||
if pgsql.SupportsLength(baseType) {
|
||||
if matches := lengthRegex.FindStringSubmatch(content); len(matches) == 2 {
|
||||
if n, err := strconv.Atoi(matches[1]); err == nil && n > 0 {
|
||||
column.Length = n
|
||||
column.Type = fmt.Sprintf("%s(%d)", column.Type, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pgsql.SupportsPrecision(baseType) {
|
||||
if matches := precisionRegex.FindStringSubmatch(content); len(matches) == 2 {
|
||||
if p, err := strconv.Atoi(matches[1]); err == nil && p > 0 {
|
||||
column.Precision = p
|
||||
if sm := scaleRegex.FindStringSubmatch(content); len(sm) == 2 {
|
||||
if s, err := strconv.Atoi(sm[1]); err == nil && s >= 0 {
|
||||
column.Scale = s
|
||||
column.Type = fmt.Sprintf("%s(%d,%d)", column.Type, p, s)
|
||||
}
|
||||
} else {
|
||||
column.Type = fmt.Sprintf("%s(%d)", column.Type, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(content, "nullable: true") || strings.Contains(content, "nullable:true") {
|
||||
column.NotNull = false
|
||||
}
|
||||
|
||||
60
pkg/readers/typeorm/reader_test.go
Normal file
60
pkg/readers/typeorm/reader_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package typeorm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestParseColumnOptions_PreservesTypeModifiers(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
table := models.InitTable("users", "public")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
decorator string
|
||||
wantType string
|
||||
wantLength int
|
||||
wantPrecision int
|
||||
wantScale int
|
||||
}{
|
||||
{
|
||||
name: "varchar with length",
|
||||
decorator: `@Column({ type: 'varchar', length: 255 })`,
|
||||
wantType: "varchar(255)",
|
||||
wantLength: 255,
|
||||
},
|
||||
{
|
||||
name: "numeric with precision and scale",
|
||||
decorator: `@Column({ type: 'numeric', precision: 10, scale: 2 })`,
|
||||
wantType: "numeric(10,2)",
|
||||
wantPrecision: 10,
|
||||
wantScale: 2,
|
||||
},
|
||||
{
|
||||
name: "custom type with explicit modifier is preserved",
|
||||
decorator: `@Column({ type: 'vector(1536)' })`,
|
||||
wantType: "vector(1536)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
col := models.InitColumn("sample", table.Name, table.Schema)
|
||||
reader.parseColumnOptions(tt.decorator, col, table)
|
||||
|
||||
if col.Type != tt.wantType {
|
||||
t.Fatalf("column type = %q, want %q", col.Type, tt.wantType)
|
||||
}
|
||||
if col.Length != tt.wantLength {
|
||||
t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength)
|
||||
}
|
||||
if col.Precision != tt.wantPrecision {
|
||||
t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision)
|
||||
}
|
||||
if col.Scale != tt.wantScale {
|
||||
t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user