From 6388daba5669d08653524ca1d9d657c712ede2f9 Mon Sep 17 00:00:00 2001 From: Hein Date: Sat, 10 Jan 2026 13:17:30 +0200 Subject: [PATCH] =?UTF-8?q?feat(reader):=20=F0=9F=8E=89=20Add=20support=20?= =?UTF-8?q?for=20multi-file=20DBML=20loading?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement directory reading for DBML files. * Merge schemas and tables from multiple files. * Add tests for multi-file loading and merging behavior. * Enhance file discovery and sorting logic. --- pkg/readers/dbml/reader.go | 316 +++++++++++++++++- pkg/readers/dbml/reader_test.go | 284 ++++++++++++++++ pkg/writers/dbml/writer.go | 10 +- tests/assets/dbml/multifile/1_users.dbml | 5 + tests/assets/dbml/multifile/2_posts.dbml | 8 + .../assets/dbml/multifile/3_add_columns.dbml | 5 + tests/assets/dbml/multifile/9_refs.dbml | 10 + 7 files changed, 626 insertions(+), 12 deletions(-) create mode 100644 tests/assets/dbml/multifile/1_users.dbml create mode 100644 tests/assets/dbml/multifile/2_posts.dbml create mode 100644 tests/assets/dbml/multifile/3_add_columns.dbml create mode 100644 tests/assets/dbml/multifile/9_refs.dbml diff --git a/pkg/readers/dbml/reader.go b/pkg/readers/dbml/reader.go index 834a631..24a0302 100644 --- a/pkg/readers/dbml/reader.go +++ b/pkg/readers/dbml/reader.go @@ -4,7 +4,9 @@ import ( "bufio" "fmt" "os" + "path/filepath" "regexp" + "sort" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" @@ -24,11 +26,23 @@ func NewReader(options *readers.ReaderOptions) *Reader { } // ReadDatabase reads and parses DBML input, returning a Database model +// If FilePath points to a directory, all .dbml files are loaded and merged func (r *Reader) ReadDatabase() (*models.Database, error) { if r.options.FilePath == "" { return nil, fmt.Errorf("file path is required for DBML reader") } + // Check whether the path is a directory + info, err := os.Stat(r.options.FilePath) + if err != nil { + return nil, fmt.Errorf("failed to stat path: %w", err) + } + + if info.IsDir() { + return r.readDirectoryDBML(r.options.FilePath) + } + + // Single file - existing logic content, err := os.ReadFile(r.options.FilePath) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } @@ -67,6 +81,53 @@ func (r *Reader) ReadTable() (*models.Table, error) { return schema.Tables[0], nil } +// readDirectoryDBML processes all .dbml files in a directory and +// returns a single merged Database model +func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) { + // Discover and sort DBML files + files, err := r.discoverDBMLFiles(dirPath) + if err != nil { + return nil, fmt.Errorf("failed to discover DBML files: %w", err) + } + + // If no files found, return empty database + if len(files) == 0 { + db := models.InitDatabase("database") + if r.options.Metadata != nil { + if name, ok := r.options.Metadata["name"].(string); ok { + db.Name = name + } + } + return db, nil + } + + // Database to merge into; the first parsed file initializes it + var db *models.Database + + // Process each file in sorted order + for _, filePath := range files { + content, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", filePath, err) + } + + fileDB, err := r.parseDBML(string(content)) + if err != nil { + return nil, fmt.Errorf("failed to parse file %s: %w", filePath, err) + } + + // First file initializes the database + if db == nil { + db = fileDB + } else { + // Subsequent files are merged + mergeDatabase(db, fileDB) + } + } + + return db, nil +} + +// stripQuotes 
removes surrounding quotes from an identifier func stripQuotes(s string) string { s = strings.TrimSpace(s) @@ -76,6 +137,235 @@ func stripQuotes(s string) string { return s } +// parseFilePrefix extracts numeric prefix from filename +// Examples: "1_schema.dbml" -> (1, true), "tables.dbml" -> (0, false) +func parseFilePrefix(filename string) (int, bool) { + base := filepath.Base(filename) + re := regexp.MustCompile(`^(\d+)[_-]`) + matches := re.FindStringSubmatch(base) + if len(matches) > 1 { + var prefix int + _, err := fmt.Sscanf(matches[1], "%d", &prefix) + if err == nil { + return prefix, true + } + } + return 0, false +} + +// hasCommentedRefs scans file content for commented-out Ref statements +// Returns true if file contains lines like: // Ref: table.col > other.col +func hasCommentedRefs(filePath string) (bool, error) { + content, err := os.ReadFile(filePath) + if err != nil { + return false, err + } + + scanner := bufio.NewScanner(strings.NewReader(string(content))) + commentedRefRegex := regexp.MustCompile(`^\s*//.*Ref:\s+`) + + for scanner.Scan() { + line := scanner.Text() + if commentedRefRegex.MatchString(line) { + return true, nil + } + } + + return false, nil +} + +// discoverDBMLFiles finds all .dbml files in directory and returns them sorted +func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) { + pattern := filepath.Join(dirPath, "*.dbml") + files, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob .dbml files: %w", err) + } + + return sortDBMLFiles(files), nil +} + +// sortDBMLFiles sorts files by: +// 1. Files without commented refs (by numeric prefix, then alphabetically) +// 2. Files with commented refs (by numeric prefix, then alphabetically) +func sortDBMLFiles(files []string) []string { + // Create a slice to hold file info for sorting + type fileInfo struct { + path string + hasCommented bool + prefix int + hasPrefix bool + basename string + } + + fileInfos := make([]fileInfo, 0, len(files)) + + for _, file := range files { + hasCommented, err := hasCommentedRefs(file) + if err != nil { + // If we can't read the file, treat it as not having commented refs + hasCommented = false + } + + prefix, hasPrefix := parseFilePrefix(file) + basename := filepath.Base(file) + + fileInfos = append(fileInfos, fileInfo{ + path: file, + hasCommented: hasCommented, + prefix: prefix, + hasPrefix: hasPrefix, + basename: basename, + }) + } + + // Sort by: hasCommented (false first), hasPrefix (true first), prefix, basename + sort.Slice(fileInfos, func(i, j int) bool { + // First, sort by commented refs (files without commented refs come first) + if fileInfos[i].hasCommented != fileInfos[j].hasCommented { + return !fileInfos[i].hasCommented + } + + // Then by presence of prefix (files with prefix come first) + if fileInfos[i].hasPrefix != fileInfos[j].hasPrefix { + return fileInfos[i].hasPrefix + } + + // If both have prefix, sort by prefix value + if fileInfos[i].hasPrefix && fileInfos[j].hasPrefix { + if fileInfos[i].prefix != fileInfos[j].prefix { + return fileInfos[i].prefix < fileInfos[j].prefix + } + } + + // Finally, sort alphabetically by basename + return fileInfos[i].basename < fileInfos[j].basename + }) + + // Extract sorted paths + sortedFiles := make([]string, len(fileInfos)) + for i, info := range fileInfos { + sortedFiles[i] = info.path + } + + return sortedFiles +} + +// mergeTable combines two table definitions +// Merges: Columns (map), Constraints (map), Indexes (map), Relationships (map) +// Uses first 
non-empty Description +func mergeTable(baseTable, fileTable *models.Table) { + // Merge columns (later definitions overwrite matching keys) + for key, col := range fileTable.Columns { + baseTable.Columns[key] = col + } + + // Merge constraints + for key, constraint := range fileTable.Constraints { + baseTable.Constraints[key] = constraint + } + + // Merge indexes + for key, index := range fileTable.Indexes { + baseTable.Indexes[key] = index + } + + // Merge relationships + for key, rel := range fileTable.Relationships { + baseTable.Relationships[key] = rel + } + + // Use first non-empty description + if baseTable.Description == "" && fileTable.Description != "" { + baseTable.Description = fileTable.Description + } + + // Merge metadata maps + if baseTable.Metadata == nil { + baseTable.Metadata = make(map[string]any) + } + for key, val := range fileTable.Metadata { + baseTable.Metadata[key] = val + } +} + +// mergeSchema finds or creates the target schema and merges tables into it +func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) { + // Find existing schema by name (normalize names by stripping quotes) + var existingSchema *models.Schema + fileSchemaName := stripQuotes(fileSchema.Name) + for _, schema := range baseDB.Schemas { + if stripQuotes(schema.Name) == fileSchemaName { + existingSchema = schema + break + } + } + + // If schema doesn't exist, add it and return + if existingSchema == nil { + baseDB.Schemas = append(baseDB.Schemas, fileSchema) + return + } + + // Merge tables from fileSchema into existingSchema + for _, fileTable := range fileSchema.Tables { + // Find existing table by name (normalize names by stripping quotes) + var existingTable *models.Table + fileTableName := stripQuotes(fileTable.Name) + for _, table := range existingSchema.Tables { + if stripQuotes(table.Name) == fileTableName { + existingTable = table + break + } + } + + // If table doesn't exist, add it + if existingTable == nil { + existingSchema.Tables = append(existingSchema.Tables, fileTable) + } else { + // Table exists in both files - merge the new definition into it + mergeTable(existingTable, fileTable) + } + } + + // Merge other schema properties + existingSchema.Views = append(existingSchema.Views, fileSchema.Views...) + existingSchema.Sequences = append(existingSchema.Sequences, fileSchema.Sequences...) + existingSchema.Scripts = append(existingSchema.Scripts, fileSchema.Scripts...) + + // Merge permissions + if existingSchema.Permissions == nil { + existingSchema.Permissions = make(map[string]string) + } + for key, val := range fileSchema.Permissions { + existingSchema.Permissions[key] = val + } + + // Merge metadata + if existingSchema.Metadata == nil { + existingSchema.Metadata = make(map[string]any) + } + for key, val := range fileSchema.Metadata { + existingSchema.Metadata[key] = val + } +} + +// mergeDatabase merges schemas from fileDB into baseDB +func mergeDatabase(baseDB, fileDB *models.Database) { + // Merge each schema from fileDB + for _, fileSchema := range fileDB.Schemas { + mergeSchema(baseDB, fileSchema) + } + + // Merge domains + baseDB.Domains = append(baseDB.Domains, fileDB.Domains...) 
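+ // Note: domains are appended without de-duplication; a domain declared in several files is kept once per file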
+ + // Use first non-empty description + if baseDB.Description == "" && fileDB.Description != "" { + baseDB.Description = fileDB.Description + } +} + // parseDBML parses DBML content and returns a Database model func (r *Reader) parseDBML(content string) (*models.Database, error) { db := models.InitDatabase("database") @@ -332,27 +622,31 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index { // Format: (columns) [attributes] OR columnname [attributes] var columns []string - if strings.Contains(line, "(") && strings.Contains(line, ")") { + // Find the attributes section to avoid parsing parentheses in notes/attributes + attrStart := strings.Index(line, "[") + columnPart := line + if attrStart > 0 { + columnPart = line[:attrStart] + } + + if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") { // Multi-column format: (col1, col2) [attributes] - colStart := strings.Index(line, "(") - colEnd := strings.Index(line, ")") + colStart := strings.Index(columnPart, "(") + colEnd := strings.Index(columnPart, ")") if colStart >= colEnd { return nil } - columnsStr := line[colStart+1 : colEnd] + columnsStr := columnPart[colStart+1 : colEnd] for _, col := range strings.Split(columnsStr, ",") { columns = append(columns, stripQuotes(strings.TrimSpace(col))) } - } else if strings.Contains(line, "[") { + } else if attrStart > 0 { // Single column format: columnname [attributes] // Extract column name before the bracket - idx := strings.Index(line, "[") - if idx > 0 { - colName := strings.TrimSpace(line[:idx]) - if colName != "" { - columns = []string{stripQuotes(colName)} - } + colName := strings.TrimSpace(columnPart) + if colName != "" { + columns = []string{stripQuotes(colName)} } } diff --git a/pkg/readers/dbml/reader_test.go b/pkg/readers/dbml/reader_test.go index 4b6a9a2..21edf0f 100644 --- a/pkg/readers/dbml/reader_test.go +++ b/pkg/readers/dbml/reader_test.go @@ -1,6 +1,7 @@ package dbml import ( + "os" "path/filepath" "testing" @@ -517,3 +518,286 @@ func TestGetForeignKeys(t *testing.T) { t.Error("Expected foreign key constraint type") } } + +// Tests for multi-file directory loading + +func TestReadDirectory_MultipleFiles(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + // Should have public schema + if len(db.Schemas) == 0 { + t.Fatal("Expected at least one schema") + } + + var publicSchema *models.Schema + for _, schema := range db.Schemas { + if schema.Name == "public" { + publicSchema = schema + break + } + } + + if publicSchema == nil { + t.Fatal("Public schema not found") + } + + // Should have 3 tables: users, posts, comments + if len(publicSchema.Tables) != 3 { + t.Fatalf("Expected 3 tables, got %d", len(publicSchema.Tables)) + } + + // Find tables + var usersTable, postsTable, commentsTable *models.Table + for _, table := range publicSchema.Tables { + switch table.Name { + case "users": + usersTable = table + case "posts": + postsTable = table + case "comments": + commentsTable = table + } + } + + if usersTable == nil { + t.Fatal("Users table not found") + } + if postsTable == nil { + t.Fatal("Posts table not found") + } + if commentsTable == nil { + t.Fatal("Comments table not found") + } + + // Verify users table has merged columns from 1_users.dbml 
and 3_add_columns.dbml + expectedUserColumns := []string{"id", "email", "name", "created_at"} + if len(usersTable.Columns) != len(expectedUserColumns) { + t.Errorf("Expected %d columns in users table, got %d", len(expectedUserColumns), len(usersTable.Columns)) + } + + for _, colName := range expectedUserColumns { + if _, exists := usersTable.Columns[colName]; !exists { + t.Errorf("Expected column '%s' in users table", colName) + } + } + + // Verify posts table columns + expectedPostColumns := []string{"id", "user_id", "title", "content", "created_at"} + for _, colName := range expectedPostColumns { + if _, exists := postsTable.Columns[colName]; !exists { + t.Errorf("Expected column '%s' in posts table", colName) + } + } +} + +func TestReadDirectory_TableMerging(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + // Find users table + var usersTable *models.Table + for _, schema := range db.Schemas { + for _, table := range schema.Tables { + if table.Name == "users" && schema.Name == "public" { + usersTable = table + break + } + } + } + + if usersTable == nil { + t.Fatal("Users table not found") + } + + // Verify columns from file 1 (id, email) + if _, exists := usersTable.Columns["id"]; !exists { + t.Error("Column 'id' from 1_users.dbml not found") + } + if _, exists := usersTable.Columns["email"]; !exists { + t.Error("Column 'email' from 1_users.dbml not found") + } + + // Verify columns from file 3 (name, created_at) + if _, exists := usersTable.Columns["name"]; !exists { + t.Error("Column 'name' from 3_add_columns.dbml not found") + } + if _, exists := usersTable.Columns["created_at"]; !exists { + t.Error("Column 'created_at' from 3_add_columns.dbml not found") + } + + // Verify column properties from file 1 + emailCol := usersTable.Columns["email"] + if !emailCol.NotNull { + t.Error("Email column should be not null (from 1_users.dbml)") + } + if emailCol.Type != "varchar(255)" { + t.Errorf("Expected email type 'varchar(255)', got '%s'", emailCol.Type) + } +} + +func TestReadDirectory_CommentedRefsLast(t *testing.T) { + // This test verifies that files with commented refs are processed last + // by checking that the file discovery returns them in the correct order + dirPath := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile") + + opts := &readers.ReaderOptions{ + FilePath: dirPath, + } + + reader := NewReader(opts) + files, err := reader.discoverDBMLFiles(dirPath) + if err != nil { + t.Fatalf("discoverDBMLFiles() error = %v", err) + } + + if len(files) < 2 { + t.Skip("Not enough files to test ordering") + } + + // Check that 9_refs.dbml (which has commented refs) comes last + lastFile := filepath.Base(files[len(files)-1]) + if lastFile != "9_refs.dbml" { + t.Errorf("Expected last file to be '9_refs.dbml' (has commented refs), got '%s'", lastFile) + } + + // Check that numbered files without commented refs come first + firstFile := filepath.Base(files[0]) + if firstFile != "1_users.dbml" { + t.Errorf("Expected first file to be '1_users.dbml', got '%s'", firstFile) + } +} + +func TestReadDirectory_EmptyDirectory(t *testing.T) { + // Create a temporary empty directory + tmpDir := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "empty_test_dir") + err := os.MkdirAll(tmpDir, 0755) + if err != nil { + t.Fatalf("Failed to create temp 
directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + opts := &readers.ReaderOptions{ + FilePath: tmpDir, + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() should not error on empty directory, got: %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + // Empty directory should return empty database + if len(db.Schemas) != 0 { + t.Errorf("Expected 0 schemas for empty directory, got %d", len(db.Schemas)) + } +} + +func TestReadDatabase_BackwardCompat(t *testing.T) { + // Test that single file loading still works + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "simple.dbml"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + if len(db.Schemas) == 0 { + t.Fatal("Expected at least one schema") + } + + schema := db.Schemas[0] + if len(schema.Tables) != 1 { + t.Fatalf("Expected 1 table, got %d", len(schema.Tables)) + } + + table := schema.Tables[0] + if table.Name != "users" { + t.Errorf("Expected table name 'users', got '%s'", table.Name) + } +} + +func TestParseFilePrefix(t *testing.T) { + tests := []struct { + filename string + wantPrefix int + wantHas bool + }{ + {"1_schema.dbml", 1, true}, + {"2_tables.dbml", 2, true}, + {"10_relationships.dbml", 10, true}, + {"99_data.dbml", 99, true}, + {"schema.dbml", 0, false}, + {"tables_no_prefix.dbml", 0, false}, + {"/path/to/1_file.dbml", 1, true}, + {"/path/to/file.dbml", 0, false}, + {"1-file.dbml", 1, true}, + {"2-another.dbml", 2, true}, + } + + for _, tt := range tests { + t.Run(tt.filename, func(t *testing.T) { + gotPrefix, gotHas := parseFilePrefix(tt.filename) + if gotPrefix != tt.wantPrefix { + t.Errorf("parseFilePrefix(%s) prefix = %d, want %d", tt.filename, gotPrefix, tt.wantPrefix) + } + if gotHas != tt.wantHas { + t.Errorf("parseFilePrefix(%s) hasPrefix = %v, want %v", tt.filename, gotHas, tt.wantHas) + } + }) + } +} + +func TestHasCommentedRefs(t *testing.T) { + // Test with the actual multifile test fixtures + tests := []struct { + filename string + wantHas bool + }{ + {filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "1_users.dbml"), false}, + {filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "2_posts.dbml"), false}, + {filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "3_add_columns.dbml"), false}, + {filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "9_refs.dbml"), true}, + } + + for _, tt := range tests { + t.Run(filepath.Base(tt.filename), func(t *testing.T) { + gotHas, err := hasCommentedRefs(tt.filename) + if err != nil { + t.Fatalf("hasCommentedRefs() error = %v", err) + } + if gotHas != tt.wantHas { + t.Errorf("hasCommentedRefs(%s) = %v, want %v", filepath.Base(tt.filename), gotHas, tt.wantHas) + } + }) + } +} diff --git a/pkg/writers/dbml/writer.go b/pkg/writers/dbml/writer.go index af215ec..886defd 100644 --- a/pkg/writers/dbml/writer.go +++ b/pkg/writers/dbml/writer.go @@ -126,7 +126,15 @@ func (w *Writer) tableToDBML(t *models.Table) string { attrs = append(attrs, "increment") } if column.Default != nil { - attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default)) + // Check if default value contains backticks (DBML expressions like `now()`) + defaultStr := fmt.Sprintf("%v", column.Default) + if 
strings.HasPrefix(defaultStr, "`") && strings.HasSuffix(defaultStr, "`") { + // Already an expression with backticks, use as-is + attrs = append(attrs, fmt.Sprintf("default: %s", defaultStr)) + } else { + // Regular value, wrap in single quotes + attrs = append(attrs, fmt.Sprintf("default: '%v'", column.Default)) + } } if len(attrs) > 0 { diff --git a/tests/assets/dbml/multifile/1_users.dbml b/tests/assets/dbml/multifile/1_users.dbml new file mode 100644 index 0000000..0161412 --- /dev/null +++ b/tests/assets/dbml/multifile/1_users.dbml @@ -0,0 +1,5 @@ +// First file - users table basic structure +Table public.users { + id bigint [pk, increment] + email varchar(255) [unique, not null] +} diff --git a/tests/assets/dbml/multifile/2_posts.dbml b/tests/assets/dbml/multifile/2_posts.dbml new file mode 100644 index 0000000..02c5622 --- /dev/null +++ b/tests/assets/dbml/multifile/2_posts.dbml @@ -0,0 +1,8 @@ +// Second file - posts table +Table public.posts { + id bigint [pk, increment] + user_id bigint [not null] + title varchar(200) [not null] + content text + created_at timestamp [not null] +} diff --git a/tests/assets/dbml/multifile/3_add_columns.dbml b/tests/assets/dbml/multifile/3_add_columns.dbml new file mode 100644 index 0000000..f34bc77 --- /dev/null +++ b/tests/assets/dbml/multifile/3_add_columns.dbml @@ -0,0 +1,5 @@ +// Third file - adds more columns to users table (tests merging) +Table public.users { + name varchar(100) + created_at timestamp [not null] +} diff --git a/tests/assets/dbml/multifile/9_refs.dbml b/tests/assets/dbml/multifile/9_refs.dbml new file mode 100644 index 0000000..5256be5 --- /dev/null +++ b/tests/assets/dbml/multifile/9_refs.dbml @@ -0,0 +1,10 @@ +// File with commented-out refs - should load last +// Contains relationships that depend on earlier tables + +// Ref: public.posts.user_id > public.users.id [ondelete: CASCADE] + +Table public.comments { + id bigint [pk, increment] + post_id bigint [not null] + content text [not null] +}
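
Reviewer note: a minimal sketch of driving the new directory mode end to end. The patch itself confirms NewReader, ReaderOptions.FilePath, and ReadDatabase; the import paths below are assumptions inferred from the module path visible in the diff (git.warky.dev/wdevs/relspecgo) and the pkg/readers/dbml layout, so adjust to the actual package tree.

package main

import (
	"fmt"
	"log"

	// Assumed import paths - not confirmed by the patch
	"git.warky.dev/wdevs/relspecgo/pkg/readers"
	"git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
)

func main() {
	// FilePath may now point at a directory: every *.dbml file in it is
	// discovered, sorted (numeric prefixes first, files containing
	// commented-out "// Ref:" lines last), parsed, and merged.
	r := dbml.NewReader(&readers.ReaderOptions{
		FilePath: "tests/assets/dbml/multifile",
	})

	db, err := r.ReadDatabase()
	if err != nil {
		log.Fatalf("read dbml: %v", err)
	}

	// With the fixtures added in this patch, public.users ends up with
	// four columns merged from 1_users.dbml and 3_add_columns.dbml.
	for _, schema := range db.Schemas {
		for _, table := range schema.Tables {
			fmt.Printf("%s.%s: %d columns\n", schema.Name, table.Name, len(table.Columns))
		}
	}
}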