test(drawdb): add test for converting column types with modifiers
* Implement tests to ensure explicit type modifiers are preserved during conversion. * Validate behavior for varchar, numeric, and custom vector types.
This commit is contained in:
@@ -567,110 +567,182 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
||||
// parseColumn parses a DBML column definition
|
||||
func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) {
|
||||
// Format: column_name type [attributes] // comment
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 2 {
|
||||
lineNoComment, inlineComment := splitInlineComment(line)
|
||||
signature, attrs := splitColumnSignatureAndAttrs(lineNoComment)
|
||||
columnName, columnType, ok := parseColumnSignature(signature)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
columnName := stripQuotes(parts[0])
|
||||
columnType := stripQuotes(parts[1])
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schemaName)
|
||||
column.Type = columnType
|
||||
|
||||
var constraint *models.Constraint
|
||||
|
||||
// Parse attributes in brackets
|
||||
if strings.Contains(line, "[") && strings.Contains(line, "]") {
|
||||
attrStart := strings.Index(line, "[")
|
||||
attrEnd := strings.Index(line, "]")
|
||||
if attrStart < attrEnd {
|
||||
attrs := line[attrStart+1 : attrEnd]
|
||||
attrList := strings.Split(attrs, ",")
|
||||
if attrs != "" {
|
||||
attrList := strings.Split(attrs, ",")
|
||||
|
||||
for _, attr := range attrList {
|
||||
attr = strings.TrimSpace(attr)
|
||||
for _, attr := range attrList {
|
||||
attr = strings.TrimSpace(attr)
|
||||
|
||||
if strings.Contains(attr, "primary key") || attr == "pk" {
|
||||
column.IsPrimaryKey = true
|
||||
column.NotNull = true
|
||||
} else if strings.Contains(attr, "not null") {
|
||||
column.NotNull = true
|
||||
} else if attr == "increment" {
|
||||
column.AutoIncrement = true
|
||||
} else if strings.HasPrefix(attr, "default:") {
|
||||
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
uniqueConstraint.Table = tableName
|
||||
uniqueConstraint.Columns = []string{columnName}
|
||||
// Store it to be added later
|
||||
if constraint == nil {
|
||||
constraint = uniqueConstraint
|
||||
}
|
||||
} else if strings.HasPrefix(attr, "note:") {
|
||||
// Parse column note/comment
|
||||
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
|
||||
column.Comment = strings.Trim(note, "'\"")
|
||||
} else if strings.HasPrefix(attr, "ref:") {
|
||||
// Parse inline reference
|
||||
// DBML semantics depend on context:
|
||||
// - On FK column: ref: < target means "this FK references target"
|
||||
// - On PK column: ref: < source means "source references this PK" (reverse notation)
|
||||
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
|
||||
|
||||
// Check relationship direction operator
|
||||
refOp := strings.TrimSpace(refStr)
|
||||
var isReverse bool
|
||||
if strings.HasPrefix(refOp, "<") {
|
||||
// < means "is referenced by" - only makes sense on PK columns
|
||||
isReverse = column.IsPrimaryKey
|
||||
}
|
||||
// > means "references" - always a forward FK, never reverse
|
||||
|
||||
constraint = r.parseRef(refStr)
|
||||
if constraint != nil {
|
||||
if isReverse {
|
||||
// Reverse: parsed ref is SOURCE, current column is TARGET
|
||||
// Constraint should be ON the source table
|
||||
constraint.Schema = constraint.ReferencedSchema
|
||||
constraint.Table = constraint.ReferencedTable
|
||||
constraint.Columns = constraint.ReferencedColumns
|
||||
constraint.ReferencedSchema = schemaName
|
||||
constraint.ReferencedTable = tableName
|
||||
constraint.ReferencedColumns = []string{columnName}
|
||||
} else {
|
||||
// Forward: current column is SOURCE, parsed ref is TARGET
|
||||
// Standard FK: constraint is ON current table
|
||||
constraint.Schema = schemaName
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
if strings.Contains(attr, "primary key") || attr == "pk" {
|
||||
column.IsPrimaryKey = true
|
||||
column.NotNull = true
|
||||
} else if strings.Contains(attr, "not null") {
|
||||
column.NotNull = true
|
||||
} else if attr == "increment" {
|
||||
column.AutoIncrement = true
|
||||
} else if strings.HasPrefix(attr, "default:") {
|
||||
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
uniqueConstraint.Table = tableName
|
||||
uniqueConstraint.Columns = []string{columnName}
|
||||
// Store it to be added later
|
||||
if constraint == nil {
|
||||
constraint = uniqueConstraint
|
||||
}
|
||||
} else if strings.HasPrefix(attr, "note:") {
|
||||
// Parse column note/comment
|
||||
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
|
||||
column.Comment = strings.Trim(note, "'\"")
|
||||
} else if strings.HasPrefix(attr, "ref:") {
|
||||
// Parse inline reference
|
||||
// DBML semantics depend on context:
|
||||
// - On FK column: ref: < target means "this FK references target"
|
||||
// - On PK column: ref: < source means "source references this PK" (reverse notation)
|
||||
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
|
||||
|
||||
// Check relationship direction operator
|
||||
refOp := strings.TrimSpace(refStr)
|
||||
var isReverse bool
|
||||
if strings.HasPrefix(refOp, "<") {
|
||||
// < means "is referenced by" - only makes sense on PK columns
|
||||
isReverse = column.IsPrimaryKey
|
||||
}
|
||||
// > means "references" - always a forward FK, never reverse
|
||||
|
||||
constraint = r.parseRef(refStr)
|
||||
if constraint != nil {
|
||||
if isReverse {
|
||||
// Reverse: parsed ref is SOURCE, current column is TARGET
|
||||
// Constraint should be ON the source table
|
||||
constraint.Schema = constraint.ReferencedSchema
|
||||
constraint.Table = constraint.ReferencedTable
|
||||
constraint.Columns = constraint.ReferencedColumns
|
||||
constraint.ReferencedSchema = schemaName
|
||||
constraint.ReferencedTable = tableName
|
||||
constraint.ReferencedColumns = []string{columnName}
|
||||
} else {
|
||||
// Forward: current column is SOURCE, parsed ref is TARGET
|
||||
// Standard FK: constraint is ON current table
|
||||
constraint.Schema = schemaName
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse inline comment
|
||||
if strings.Contains(line, "//") {
|
||||
commentStart := strings.Index(line, "//")
|
||||
column.Comment = strings.TrimSpace(line[commentStart+2:])
|
||||
if inlineComment != "" {
|
||||
column.Comment = inlineComment
|
||||
}
|
||||
|
||||
return column, constraint
|
||||
}
|
||||
|
||||
func splitInlineComment(line string) (string, string) {
|
||||
commentStart := strings.Index(line, "//")
|
||||
if commentStart == -1 {
|
||||
return line, ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(line[:commentStart]), strings.TrimSpace(line[commentStart+2:])
|
||||
}
|
||||
|
||||
func splitColumnSignatureAndAttrs(line string) (string, string) {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || !strings.HasSuffix(trimmed, "]") {
|
||||
return trimmed, ""
|
||||
}
|
||||
|
||||
bracketDepth := 0
|
||||
for i := len(trimmed) - 1; i >= 0; i-- {
|
||||
switch trimmed[i] {
|
||||
case ']':
|
||||
bracketDepth++
|
||||
case '[':
|
||||
bracketDepth--
|
||||
if bracketDepth == 0 {
|
||||
// DBML attributes are a trailing [ ... ] block preceded by whitespace.
|
||||
// This avoids confusing array types like text[] with attribute blocks.
|
||||
if i > 0 && (trimmed[i-1] == ' ' || trimmed[i-1] == '\t') {
|
||||
return strings.TrimSpace(trimmed[:i]), strings.TrimSpace(trimmed[i+1 : len(trimmed)-1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return trimmed, ""
|
||||
}
|
||||
|
||||
func parseColumnSignature(signature string) (string, string, bool) {
|
||||
signature = strings.TrimSpace(signature)
|
||||
if signature == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
var splitAt int
|
||||
if signature[0] == '"' || signature[0] == '\'' {
|
||||
quote := signature[0]
|
||||
splitAt = 1
|
||||
for splitAt < len(signature) {
|
||||
if signature[splitAt] == quote {
|
||||
splitAt++
|
||||
break
|
||||
}
|
||||
splitAt++
|
||||
}
|
||||
} else {
|
||||
for splitAt < len(signature) && signature[splitAt] != ' ' && signature[splitAt] != '\t' {
|
||||
splitAt++
|
||||
}
|
||||
}
|
||||
|
||||
if splitAt <= 0 || splitAt >= len(signature) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
columnName := stripQuotes(strings.TrimSpace(signature[:splitAt]))
|
||||
columnType := stripWrappingQuotes(strings.TrimSpace(signature[splitAt:]))
|
||||
if columnName == "" || columnType == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
return columnName, columnType, true
|
||||
}
|
||||
|
||||
func stripWrappingQuotes(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
|
||||
return s[1 : len(s)-1]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// parseIndex parses a DBML index definition
|
||||
func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
||||
// Format: (columns) [attributes] OR columnname [attributes]
|
||||
|
||||
@@ -839,6 +839,67 @@ func TestConstraintNaming(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseColumn_PostgresTypes(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
wantName string
|
||||
wantType string
|
||||
wantNotNull bool
|
||||
wantComment string
|
||||
}{
|
||||
{
|
||||
name: "array type with attrs",
|
||||
line: "tags text[] [not null]",
|
||||
wantName: "tags",
|
||||
wantType: "text[]",
|
||||
wantNotNull: true,
|
||||
},
|
||||
{
|
||||
name: "vector with dimension",
|
||||
line: "embedding vector(1536)",
|
||||
wantName: "embedding",
|
||||
wantType: "vector(1536)",
|
||||
},
|
||||
{
|
||||
name: "multi word timestamp type",
|
||||
line: "published_at timestamp with time zone",
|
||||
wantName: "published_at",
|
||||
wantType: "timestamp with time zone",
|
||||
},
|
||||
{
|
||||
name: "array type with inline comment",
|
||||
line: "labels varchar(20)[] // column labels",
|
||||
wantName: "labels",
|
||||
wantType: "varchar(20)[]",
|
||||
wantComment: "column labels",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
col, _ := reader.parseColumn(tt.line, "events", "public")
|
||||
if col == nil {
|
||||
t.Fatalf("parseColumn() returned nil column")
|
||||
}
|
||||
if col.Name != tt.wantName {
|
||||
t.Errorf("column name = %q, want %q", col.Name, tt.wantName)
|
||||
}
|
||||
if col.Type != tt.wantType {
|
||||
t.Errorf("column type = %q, want %q", col.Type, tt.wantType)
|
||||
}
|
||||
if col.NotNull != tt.wantNotNull {
|
||||
t.Errorf("column not null = %v, want %v", col.NotNull, tt.wantNotNull)
|
||||
}
|
||||
if col.Comment != tt.wantComment {
|
||||
t.Errorf("column comment = %q, want %q", col.Comment, tt.wantComment)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getKeys[V any](m map[string]V) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
|
||||
Reference in New Issue
Block a user