feat(reader): 🎉 Add support for multi-file DBML loading
All checks were successful
CI / Test (1.24) (push) Successful in -27m13s
CI / Test (1.25) (push) Successful in -27m5s
CI / Build (push) Successful in -27m16s
CI / Lint (push) Successful in -27m0s
Integration Tests / Integration Tests (push) Successful in -27m14s
Release / Build and Release (push) Successful in -25m52s

* Implement directory reading for DBML files.
* Merge schemas and tables from multiple files.
* Add tests for multi-file loading and merging behavior.
* Enhance file discovery and sorting logic.
This commit is contained in:
2026-01-10 13:17:30 +02:00
parent f6c3f2b460
commit 6388daba56
7 changed files with 626 additions and 12 deletions

View File

@@ -4,7 +4,9 @@ import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
@@ -24,11 +26,23 @@ func NewReader(options *readers.ReaderOptions) *Reader {
}
// ReadDatabase reads and parses DBML input, returning a Database model
// If FilePath points to a directory, all .dbml files are loaded and merged
func (r *Reader) ReadDatabase() (*models.Database, error) {
if r.options.FilePath == "" {
return nil, fmt.Errorf("file path is required for DBML reader")
}
// Check if path is a directory
info, err := os.Stat(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to stat path: %w", err)
}
if info.IsDir() {
return r.readDirectoryDBML(r.options.FilePath)
}
// Single file - existing logic
content, err := os.ReadFile(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
@@ -67,6 +81,53 @@ func (r *Reader) ReadTable() (*models.Table, error) {
return schema.Tables[0], nil
}
// readDirectoryDBML loads every .dbml file found in dirPath and folds the
// results into a single Database model.
//
// Files are handled in the order returned by discoverDBMLFiles. The database
// parsed from the first file becomes the result and every later file is
// merged into it with mergeDatabase. When the directory holds no .dbml files
// a fresh database created with models.InitDatabase("database") is returned;
// its Name is replaced by the string stored under the "name" metadata option
// when one is set. Any discovery, read or parse failure aborts the load and is
// returned wrapped with the offending path.
func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
	paths, err := r.discoverDBMLFiles(dirPath)
	if err != nil {
		return nil, fmt.Errorf("failed to discover DBML files: %w", err)
	}

	// Nothing to parse: hand back an empty, optionally renamed, database.
	if len(paths) == 0 {
		empty := models.InitDatabase("database")
		if r.options.Metadata != nil {
			if name, ok := r.options.Metadata["name"].(string); ok {
				empty.Name = name
			}
		}
		return empty, nil
	}

	var merged *models.Database
	for _, file := range paths {
		raw, readErr := os.ReadFile(file)
		if readErr != nil {
			return nil, fmt.Errorf("failed to read file %s: %w", file, readErr)
		}

		parsed, parseErr := r.parseDBML(string(raw))
		if parseErr != nil {
			return nil, fmt.Errorf("failed to parse file %s: %w", file, parseErr)
		}

		// The first parsed file seeds the result; the rest are merged in.
		if merged == nil {
			merged = parsed
			continue
		}
		mergeDatabase(merged, parsed)
	}

	return merged, nil
}
// stripQuotes removes surrounding quotes from an identifier
func stripQuotes(s string) string {
s = strings.TrimSpace(s)
@@ -76,6 +137,235 @@ func stripQuotes(s string) string {
return s
}
// filePrefixPattern matches a run of leading ASCII digits followed by "_" or
// "-" at the very start of a file's base name. It is compiled once at package
// initialisation rather than on every parseFilePrefix call.
var filePrefixPattern = regexp.MustCompile(`^(\d+)[_-]`)

// parseFilePrefix extracts the numeric ordering prefix from a file name.
//
// Only the base name is inspected, so directory components are ignored. The
// prefix must be one or more digits immediately followed by "_" or "-".
// It returns the parsed number and true when such a prefix exists, and
// (0, false) otherwise, including when the digits do not fit in an int.
//
// Examples: "1_schema.dbml" -> (1, true), "1-file.dbml" -> (1, true),
// "tables.dbml" -> (0, false)
func parseFilePrefix(filename string) (int, bool) {
	matches := filePrefixPattern.FindStringSubmatch(filepath.Base(filename))
	if len(matches) < 2 {
		return 0, false
	}

	var prefix int
	if _, err := fmt.Sscanf(matches[1], "%d", &prefix); err != nil {
		// Out-of-range digit runs are treated the same as "no prefix".
		return 0, false
	}
	return prefix, true
}
// commentedRefPattern matches a line that is a "//" comment (optionally
// indented) containing "Ref:" followed by at least one whitespace character.
// It is compiled once instead of on every hasCommentedRefs call.
var commentedRefPattern = regexp.MustCompile(`^\s*//.*Ref:\s+`)

// hasCommentedRefs reports whether the file at filePath contains a
// commented-out Ref statement, i.e. a line such as:
//
//	// Ref: table.col > other.col
//
// The file is read in full and checked line by line. An error is returned
// when the file cannot be read or the line scan fails.
func hasCommentedRefs(filePath string) (bool, error) {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return false, err
	}

	scanner := bufio.NewScanner(strings.NewReader(string(content)))
	// A default Scanner gives up on lines longer than 64 KiB, which would
	// silently hide every commented Ref after such a line. Allow a single
	// line to be as long as the whole file so no line is ever rejected.
	scanner.Buffer(make([]byte, 0, 4096), len(content)+1)

	for scanner.Scan() {
		if commentedRefPattern.MatchString(scanner.Text()) {
			return true, nil
		}
	}
	if err := scanner.Err(); err != nil {
		return false, fmt.Errorf("scanning %s: %w", filePath, err)
	}
	return false, nil
}
// discoverDBMLFiles lists the .dbml files located directly inside dirPath
// (sub-directories are not descended into) and returns their paths ordered by
// sortDBMLFiles.
//
// The directory is listed with os.ReadDir instead of filepath.Glob so that:
//   - glob metacharacters in dirPath ("[", "*", "?") are taken literally
//     rather than being interpreted as part of a pattern;
//   - directories whose name happens to end in ".dbml" are skipped instead of
//     failing later when they are read as files;
//   - a directory that cannot be read produces an error instead of an empty
//     result.
func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) {
	entries, err := os.ReadDir(dirPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read directory %s: %w", dirPath, err)
	}

	files := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".dbml") {
			continue
		}
		files = append(files, filepath.Join(dirPath, entry.Name()))
	}

	return sortDBMLFiles(files), nil
}
// sortDBMLFiles returns the given paths in load order. The ordering keys, in
// decreasing priority, are:
//  1. files without commented-out refs come before files that have them;
//  2. files with a numeric prefix come before files without one;
//  3. among prefixed files, the smaller prefix comes first;
//  4. remaining ties are broken by base name, ascending.
//
// A file whose content cannot be read is ordered as if it had no commented
// refs. The input slice is not modified.
func sortDBMLFiles(files []string) []string {
	// entry caches the sort keys so each file is inspected only once.
	type entry struct {
		path      string
		base      string
		prefix    int
		numbered  bool
		commented bool
	}

	entries := make([]entry, 0, len(files))
	for _, file := range files {
		commented, err := hasCommentedRefs(file)
		if err != nil {
			// Best effort: an unreadable file is sorted with the ref-free group.
			commented = false
		}
		prefix, numbered := parseFilePrefix(file)

		entries = append(entries, entry{
			path:      file,
			base:      filepath.Base(file),
			prefix:    prefix,
			numbered:  numbered,
			commented: commented,
		})
	}

	sort.Slice(entries, func(i, j int) bool {
		a, b := entries[i], entries[j]
		switch {
		case a.commented != b.commented:
			// Exactly one has commented refs; it must be the later one.
			return b.commented
		case a.numbered != b.numbered:
			return a.numbered
		case a.numbered && a.prefix != b.prefix:
			return a.prefix < b.prefix
		default:
			return a.base < b.base
		}
	})

	ordered := make([]string, 0, len(entries))
	for _, e := range entries {
		ordered = append(ordered, e.path)
	}
	return ordered
}
// mergeMapInto copies every entry of src into dst and returns the resulting
// map. When dst is nil and src has entries a new map is allocated, so callers
// must store the returned value. Keys present in both maps end up with the
// value from src. An empty src leaves dst untouched.
func mergeMapInto[M ~map[K]V, K comparable, V any](dst, src M) M {
	if len(src) == 0 {
		return dst
	}
	if dst == nil {
		dst = make(M, len(src))
	}
	for key, val := range src {
		dst[key] = val
	}
	return dst
}

// mergeTable folds the definition in fileTable into baseTable.
//
// Columns, Constraints, Indexes and Relationships are merged key by key; when
// both tables define the same key the entry from fileTable replaces the one
// in baseTable. A nil map on baseTable is allocated on demand instead of
// causing a panic, matching the handling Metadata already had. Description
// keeps the first non-empty value. Metadata is merged key by key and is always
// left non-nil on baseTable.
func mergeTable(baseTable, fileTable *models.Table) {
	baseTable.Columns = mergeMapInto(baseTable.Columns, fileTable.Columns)
	baseTable.Constraints = mergeMapInto(baseTable.Constraints, fileTable.Constraints)
	baseTable.Indexes = mergeMapInto(baseTable.Indexes, fileTable.Indexes)
	baseTable.Relationships = mergeMapInto(baseTable.Relationships, fileTable.Relationships)

	// Use first non-empty description
	if baseTable.Description == "" && fileTable.Description != "" {
		baseTable.Description = fileTable.Description
	}

	// Metadata is allocated even when fileTable contributes nothing, so the
	// merged table always exposes a usable map.
	if baseTable.Metadata == nil {
		baseTable.Metadata = make(map[string]any)
	}
	for key, val := range fileTable.Metadata {
		baseTable.Metadata[key] = val
	}
}
// mergeSchema merges fileSchema into baseDB.
//
// Schemas are matched by name after stripQuotes normalisation. When baseDB
// has no schema with that name, fileSchema itself is appended and nothing
// else happens. Otherwise each table of fileSchema is either appended to the
// matching schema (unknown table name) or merged into the existing table via
// mergeTable (known table name, again compared after stripQuotes). Views,
// Sequences and Scripts are appended; Permissions and Metadata are merged key
// by key with values from fileSchema winning.
func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) {
	wantedSchema := stripQuotes(fileSchema.Name)

	var target *models.Schema
	for _, candidate := range baseDB.Schemas {
		if stripQuotes(candidate.Name) == wantedSchema {
			target = candidate
			break
		}
	}

	// Unknown schema: adopt it as-is.
	if target == nil {
		baseDB.Schemas = append(baseDB.Schemas, fileSchema)
		return
	}

	for _, incoming := range fileSchema.Tables {
		wantedTable := stripQuotes(incoming.Name)

		var match *models.Table
		for _, known := range target.Tables {
			if stripQuotes(known.Name) == wantedTable {
				match = known
				break
			}
		}

		if match != nil {
			// Same table defined again: fold the new definition into the old one.
			mergeTable(match, incoming)
			continue
		}
		target.Tables = append(target.Tables, incoming)
	}

	// Slice-valued properties are concatenated.
	target.Views = append(target.Views, fileSchema.Views...)
	target.Sequences = append(target.Sequences, fileSchema.Sequences...)
	target.Scripts = append(target.Scripts, fileSchema.Scripts...)

	// Map-valued properties are merged; the maps are created when missing.
	if target.Permissions == nil {
		target.Permissions = make(map[string]string)
	}
	for name, perm := range fileSchema.Permissions {
		target.Permissions[name] = perm
	}

	if target.Metadata == nil {
		target.Metadata = make(map[string]any)
	}
	for name, value := range fileSchema.Metadata {
		target.Metadata[name] = value
	}
}
// mergeDatabase folds fileDB into baseDB: every schema of fileDB goes through
// mergeSchema, the domains of fileDB are appended to those of baseDB, and the
// description of baseDB is filled from fileDB only when it is still empty.
func mergeDatabase(baseDB, fileDB *models.Database) {
	for idx := range fileDB.Schemas {
		mergeSchema(baseDB, fileDB.Schemas[idx])
	}

	baseDB.Domains = append(baseDB.Domains, fileDB.Domains...)

	// Keep the first non-empty description that was seen.
	if fileDB.Description != "" && baseDB.Description == "" {
		baseDB.Description = fileDB.Description
	}
}
// parseDBML parses DBML content and returns a Database model
func (r *Reader) parseDBML(content string) (*models.Database, error) {
db := models.InitDatabase("database")
@@ -332,27 +622,31 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Format: (columns) [attributes] OR columnname [attributes]
var columns []string
if strings.Contains(line, "(") && strings.Contains(line, ")") {
// Find the attributes section to avoid parsing parentheses in notes/attributes
attrStart := strings.Index(line, "[")
columnPart := line
if attrStart > 0 {
columnPart = line[:attrStart]
}
if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") {
// Multi-column format: (col1, col2) [attributes]
colStart := strings.Index(line, "(")
colEnd := strings.Index(line, ")")
colStart := strings.Index(columnPart, "(")
colEnd := strings.Index(columnPart, ")")
if colStart >= colEnd {
return nil
}
columnsStr := line[colStart+1 : colEnd]
columnsStr := columnPart[colStart+1 : colEnd]
for _, col := range strings.Split(columnsStr, ",") {
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
}
} else if strings.Contains(line, "[") {
} else if attrStart > 0 {
// Single column format: columnname [attributes]
// Extract column name before the bracket
idx := strings.Index(line, "[")
if idx > 0 {
colName := strings.TrimSpace(line[:idx])
if colName != "" {
columns = []string{stripQuotes(colName)}
}
colName := strings.TrimSpace(columnPart)
if colName != "" {
columns = []string{stripQuotes(colName)}
}
}

View File

@@ -1,6 +1,7 @@
package dbml
import (
"os"
"path/filepath"
"testing"
@@ -517,3 +518,286 @@ func TestGetForeignKeys(t *testing.T) {
t.Error("Expected foreign key constraint type")
}
}
// Tests for multi-file directory loading
func TestReadDirectory_MultipleFiles(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
if db == nil {
t.Fatal("ReadDatabase() returned nil database")
}
// Should have public schema
if len(db.Schemas) == 0 {
t.Fatal("Expected at least one schema")
}
var publicSchema *models.Schema
for _, schema := range db.Schemas {
if schema.Name == "public" {
publicSchema = schema
break
}
}
if publicSchema == nil {
t.Fatal("Public schema not found")
}
// Should have 3 tables: users, posts, comments
if len(publicSchema.Tables) != 3 {
t.Fatalf("Expected 3 tables, got %d", len(publicSchema.Tables))
}
// Find tables
var usersTable, postsTable, commentsTable *models.Table
for _, table := range publicSchema.Tables {
switch table.Name {
case "users":
usersTable = table
case "posts":
postsTable = table
case "comments":
commentsTable = table
}
}
if usersTable == nil {
t.Fatal("Users table not found")
}
if postsTable == nil {
t.Fatal("Posts table not found")
}
if commentsTable == nil {
t.Fatal("Comments table not found")
}
// Verify users table has merged columns from 1_users.dbml and 3_add_columns.dbml
expectedUserColumns := []string{"id", "email", "name", "created_at"}
if len(usersTable.Columns) != len(expectedUserColumns) {
t.Errorf("Expected %d columns in users table, got %d", len(expectedUserColumns), len(usersTable.Columns))
}
for _, colName := range expectedUserColumns {
if _, exists := usersTable.Columns[colName]; !exists {
t.Errorf("Expected column '%s' in users table", colName)
}
}
// Verify posts table columns
expectedPostColumns := []string{"id", "user_id", "title", "content", "created_at"}
for _, colName := range expectedPostColumns {
if _, exists := postsTable.Columns[colName]; !exists {
t.Errorf("Expected column '%s' in posts table", colName)
}
}
}
// TestReadDirectory_TableMerging checks that the users table read from the
// multifile fixture directory combines the columns declared in 1_users.dbml
// and 3_add_columns.dbml, and that the email column keeps the properties
// declared in 1_users.dbml.
func TestReadDirectory_TableMerging(t *testing.T) {
	opts := &readers.ReaderOptions{
		FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
	}

	reader := NewReader(opts)
	db, err := reader.ReadDatabase()
	if err != nil {
		t.Fatalf("ReadDatabase() error = %v", err)
	}

	// Find the users table inside the public schema.
	var usersTable *models.Table
	for _, schema := range db.Schemas {
		if schema.Name != "public" {
			continue
		}
		for _, table := range schema.Tables {
			if table.Name == "users" {
				usersTable = table
				break
			}
		}
	}
	if usersTable == nil {
		t.Fatal("Users table not found")
	}

	// Every column must be present, whichever file contributed it.
	for _, want := range []struct{ column, file string }{
		{"id", "1_users.dbml"},
		{"email", "1_users.dbml"},
		{"name", "3_add_columns.dbml"},
		{"created_at", "3_add_columns.dbml"},
	} {
		if _, exists := usersTable.Columns[want.column]; !exists {
			t.Errorf("Column '%s' from %s not found", want.column, want.file)
		}
	}

	// Stop here when email is missing: inspecting its properties would
	// otherwise dereference a missing entry instead of failing cleanly.
	emailCol, ok := usersTable.Columns["email"]
	if !ok {
		t.Fatal("Column 'email' is required to verify properties from 1_users.dbml")
	}

	// Verify column properties from file 1
	if !emailCol.NotNull {
		t.Error("Email column should be not null (from 1_users.dbml)")
	}
	if emailCol.Type != "varchar(255)" {
		t.Errorf("Expected email type 'varchar(255)', got '%s'", emailCol.Type)
	}
}
func TestReadDirectory_CommentedRefsLast(t *testing.T) {
// This test verifies that files with commented refs are processed last
// by checking that the file discovery returns them in the correct order
dirPath := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile")
opts := &readers.ReaderOptions{
FilePath: dirPath,
}
reader := NewReader(opts)
files, err := reader.discoverDBMLFiles(dirPath)
if err != nil {
t.Fatalf("discoverDBMLFiles() error = %v", err)
}
if len(files) < 2 {
t.Skip("Not enough files to test ordering")
}
// Check that 9_refs.dbml (which has commented refs) comes last
lastFile := filepath.Base(files[len(files)-1])
if lastFile != "9_refs.dbml" {
t.Errorf("Expected last file to be '9_refs.dbml' (has commented refs), got '%s'", lastFile)
}
// Check that numbered files without commented refs come first
firstFile := filepath.Base(files[0])
if firstFile != "1_users.dbml" {
t.Errorf("Expected first file to be '1_users.dbml', got '%s'", firstFile)
}
}
func TestReadDirectory_EmptyDirectory(t *testing.T) {
// Create a temporary empty directory
tmpDir := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "empty_test_dir")
err := os.MkdirAll(tmpDir, 0755)
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
opts := &readers.ReaderOptions{
FilePath: tmpDir,
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() should not error on empty directory, got: %v", err)
}
if db == nil {
t.Fatal("ReadDatabase() returned nil database")
}
// Empty directory should return empty database
if len(db.Schemas) != 0 {
t.Errorf("Expected 0 schemas for empty directory, got %d", len(db.Schemas))
}
}
func TestReadDatabase_BackwardCompat(t *testing.T) {
// Test that single file loading still works
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "simple.dbml"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
if db == nil {
t.Fatal("ReadDatabase() returned nil database")
}
if len(db.Schemas) == 0 {
t.Fatal("Expected at least one schema")
}
schema := db.Schemas[0]
if len(schema.Tables) != 1 {
t.Fatalf("Expected 1 table, got %d", len(schema.Tables))
}
table := schema.Tables[0]
if table.Name != "users" {
t.Errorf("Expected table name 'users', got '%s'", table.Name)
}
}
// TestParseFilePrefix is a table-driven check of parseFilePrefix covering
// "_" and "-" separated prefixes, multi-digit prefixes, names without a
// prefix, and names that include directory components.
func TestParseFilePrefix(t *testing.T) {
	type prefixCase struct {
		input  string
		number int
		found  bool
	}

	cases := []prefixCase{
		{input: "1_schema.dbml", number: 1, found: true},
		{input: "2_tables.dbml", number: 2, found: true},
		{input: "10_relationships.dbml", number: 10, found: true},
		{input: "99_data.dbml", number: 99, found: true},
		{input: "schema.dbml", number: 0, found: false},
		{input: "tables_no_prefix.dbml", number: 0, found: false},
		{input: "/path/to/1_file.dbml", number: 1, found: true},
		{input: "/path/to/file.dbml", number: 0, found: false},
		{input: "1-file.dbml", number: 1, found: true},
		{input: "2-another.dbml", number: 2, found: true},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			number, found := parseFilePrefix(tc.input)
			if number != tc.number {
				t.Errorf("parseFilePrefix(%s) prefix = %d, want %d", tc.input, number, tc.number)
			}
			if found != tc.found {
				t.Errorf("parseFilePrefix(%s) hasPrefix = %v, want %v", tc.input, found, tc.found)
			}
		})
	}
}
// TestHasCommentedRefs runs hasCommentedRefs against the multifile fixtures:
// only 9_refs.dbml is expected to contain a commented-out Ref line.
func TestHasCommentedRefs(t *testing.T) {
	fixture := func(name string) string {
		return filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", name)
	}

	cases := []struct {
		path     string
		expected bool
	}{
		{path: fixture("1_users.dbml"), expected: false},
		{path: fixture("2_posts.dbml"), expected: false},
		{path: fixture("3_add_columns.dbml"), expected: false},
		{path: fixture("9_refs.dbml"), expected: true},
	}

	for _, tc := range cases {
		base := filepath.Base(tc.path)
		t.Run(base, func(t *testing.T) {
			got, err := hasCommentedRefs(tc.path)
			if err != nil {
				t.Fatalf("hasCommentedRefs() error = %v", err)
			}
			if got != tc.expected {
				t.Errorf("hasCommentedRefs(%s) = %v, want %v", base, got, tc.expected)
			}
		})
	}
}

View File

@@ -126,7 +126,15 @@ func (w *Writer) tableToDBML(t *models.Table) string {
attrs = append(attrs, "increment")
}
if column.Default != nil {
attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default))
// Check if default value contains backticks (DBML expressions like `now()`)
defaultStr := fmt.Sprintf("%v", column.Default)
if strings.HasPrefix(defaultStr, "`") && strings.HasSuffix(defaultStr, "`") {
// Already an expression with backticks, use as-is
attrs = append(attrs, fmt.Sprintf("default: %s", defaultStr))
} else {
// Regular value, wrap in single quotes
attrs = append(attrs, fmt.Sprintf("default: '%v'", column.Default))
}
}
if len(attrs) > 0 {

View File

@@ -0,0 +1,5 @@
// First file - users table basic structure
Table public.users {
id bigint [pk, increment]
email varchar(255) [unique, not null]
}

View File

@@ -0,0 +1,8 @@
// Second file - posts table
Table public.posts {
id bigint [pk, increment]
user_id bigint [not null]
title varchar(200) [not null]
content text
created_at timestamp [not null]
}

View File

@@ -0,0 +1,5 @@
// Third file - adds more columns to users table (tests merging)
Table public.users {
name varchar(100)
created_at timestamp [not null]
}

View File

@@ -0,0 +1,10 @@
// File with commented-out refs - should load last
// Contains relationships that depend on earlier tables
// Ref: public.posts.user_id > public.users.id [ondelete: CASCADE]
Table public.comments {
id bigint [pk, increment]
post_id bigint [not null]
content text [not null]
}