feat(reader): 🎉 Add support for multi-file DBML loading
All checks were successful
CI / Test (1.24) (push) Successful in -27m13s
CI / Test (1.25) (push) Successful in -27m5s
CI / Build (push) Successful in -27m16s
CI / Lint (push) Successful in -27m0s
Integration Tests / Integration Tests (push) Successful in -27m14s
Release / Build and Release (push) Successful in -25m52s

* Implement directory reading for DBML files.
* Merge schemas and tables from multiple files.
* Add tests for multi-file loading and merging behavior.
* Enhance file discovery and sorting logic.
This commit is contained in:
2026-01-10 13:17:30 +02:00
parent f6c3f2b460
commit 6388daba56
7 changed files with 626 additions and 12 deletions

View File

@@ -4,7 +4,9 @@ import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
@@ -24,11 +26,23 @@ func NewReader(options *readers.ReaderOptions) *Reader {
}
// ReadDatabase reads and parses DBML input, returning a Database model
// If FilePath points to a directory, all .dbml files are loaded and merged
func (r *Reader) ReadDatabase() (*models.Database, error) {
if r.options.FilePath == "" {
return nil, fmt.Errorf("file path is required for DBML reader")
}
// Check if path is a directory
info, err := os.Stat(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to stat path: %w", err)
}
if info.IsDir() {
return r.readDirectoryDBML(r.options.FilePath)
}
// Single file - existing logic
content, err := os.ReadFile(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
@@ -67,6 +81,53 @@ func (r *Reader) ReadTable() (*models.Table, error) {
return schema.Tables[0], nil
}
// readDirectoryDBML processes all .dbml files in directory
// Returns merged Database model
func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
// Discover and sort DBML files
files, err := r.discoverDBMLFiles(dirPath)
if err != nil {
return nil, fmt.Errorf("failed to discover DBML files: %w", err)
}
// If no files found, return empty database
if len(files) == 0 {
db := models.InitDatabase("database")
if r.options.Metadata != nil {
if name, ok := r.options.Metadata["name"].(string); ok {
db.Name = name
}
}
return db, nil
}
// Initialize database (will be merged with files)
var db *models.Database
// Process each file in sorted order
for _, filePath := range files {
content, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
}
fileDB, err := r.parseDBML(string(content))
if err != nil {
return nil, fmt.Errorf("failed to parse file %s: %w", filePath, err)
}
// First file initializes the database
if db == nil {
db = fileDB
} else {
// Subsequent files are merged
mergeDatabase(db, fileDB)
}
}
return db, nil
}
// stripQuotes removes surrounding quotes from an identifier
func stripQuotes(s string) string {
s = strings.TrimSpace(s)
@@ -76,6 +137,235 @@ func stripQuotes(s string) string {
return s
}
// parseFilePrefix extracts numeric prefix from filename
// Examples: "1_schema.dbml" -> (1, true), "tables.dbml" -> (0, false)
func parseFilePrefix(filename string) (int, bool) {
base := filepath.Base(filename)
re := regexp.MustCompile(`^(\d+)[_-]`)
matches := re.FindStringSubmatch(base)
if len(matches) > 1 {
var prefix int
_, err := fmt.Sscanf(matches[1], "%d", &prefix)
if err == nil {
return prefix, true
}
}
return 0, false
}
// hasCommentedRefs scans file content for commented-out Ref statements
// Returns true if file contains lines like: // Ref: table.col > other.col
func hasCommentedRefs(filePath string) (bool, error) {
content, err := os.ReadFile(filePath)
if err != nil {
return false, err
}
scanner := bufio.NewScanner(strings.NewReader(string(content)))
commentedRefRegex := regexp.MustCompile(`^\s*//.*Ref:\s+`)
for scanner.Scan() {
line := scanner.Text()
if commentedRefRegex.MatchString(line) {
return true, nil
}
}
return false, nil
}
// discoverDBMLFiles finds all .dbml files in directory and returns them sorted
func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) {
pattern := filepath.Join(dirPath, "*.dbml")
files, err := filepath.Glob(pattern)
if err != nil {
return nil, fmt.Errorf("failed to glob .dbml files: %w", err)
}
return sortDBMLFiles(files), nil
}
// sortDBMLFiles sorts files by:
// 1. Files without commented refs (by numeric prefix, then alphabetically)
// 2. Files with commented refs (by numeric prefix, then alphabetically)
func sortDBMLFiles(files []string) []string {
// Create a slice to hold file info for sorting
type fileInfo struct {
path string
hasCommented bool
prefix int
hasPrefix bool
basename string
}
fileInfos := make([]fileInfo, 0, len(files))
for _, file := range files {
hasCommented, err := hasCommentedRefs(file)
if err != nil {
// If we can't read the file, treat it as not having commented refs
hasCommented = false
}
prefix, hasPrefix := parseFilePrefix(file)
basename := filepath.Base(file)
fileInfos = append(fileInfos, fileInfo{
path: file,
hasCommented: hasCommented,
prefix: prefix,
hasPrefix: hasPrefix,
basename: basename,
})
}
// Sort by: hasCommented (false first), hasPrefix (true first), prefix, basename
sort.Slice(fileInfos, func(i, j int) bool {
// First, sort by commented refs (files without commented refs come first)
if fileInfos[i].hasCommented != fileInfos[j].hasCommented {
return !fileInfos[i].hasCommented
}
// Then by presence of prefix (files with prefix come first)
if fileInfos[i].hasPrefix != fileInfos[j].hasPrefix {
return fileInfos[i].hasPrefix
}
// If both have prefix, sort by prefix value
if fileInfos[i].hasPrefix && fileInfos[j].hasPrefix {
if fileInfos[i].prefix != fileInfos[j].prefix {
return fileInfos[i].prefix < fileInfos[j].prefix
}
}
// Finally, sort alphabetically by basename
return fileInfos[i].basename < fileInfos[j].basename
})
// Extract sorted paths
sortedFiles := make([]string, len(fileInfos))
for i, info := range fileInfos {
sortedFiles[i] = info.path
}
return sortedFiles
}
// mergeTable combines two table definitions
// Merges: Columns (map), Constraints (map), Indexes (map), Relationships (map)
// Uses first non-empty Description
func mergeTable(baseTable, fileTable *models.Table) {
// Merge columns (map naturally merges - later keys overwrite)
for key, col := range fileTable.Columns {
baseTable.Columns[key] = col
}
// Merge constraints
for key, constraint := range fileTable.Constraints {
baseTable.Constraints[key] = constraint
}
// Merge indexes
for key, index := range fileTable.Indexes {
baseTable.Indexes[key] = index
}
// Merge relationships
for key, rel := range fileTable.Relationships {
baseTable.Relationships[key] = rel
}
// Use first non-empty description
if baseTable.Description == "" && fileTable.Description != "" {
baseTable.Description = fileTable.Description
}
// Merge metadata maps
if baseTable.Metadata == nil {
baseTable.Metadata = make(map[string]any)
}
for key, val := range fileTable.Metadata {
baseTable.Metadata[key] = val
}
}
// mergeSchema finds or creates schema and merges tables
func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) {
// Find existing schema by name (normalize names by stripping quotes)
var existingSchema *models.Schema
fileSchemaName := stripQuotes(fileSchema.Name)
for _, schema := range baseDB.Schemas {
if stripQuotes(schema.Name) == fileSchemaName {
existingSchema = schema
break
}
}
// If schema doesn't exist, add it and return
if existingSchema == nil {
baseDB.Schemas = append(baseDB.Schemas, fileSchema)
return
}
// Merge tables from fileSchema into existingSchema
for _, fileTable := range fileSchema.Tables {
// Find existing table by name (normalize names by stripping quotes)
var existingTable *models.Table
fileTableName := stripQuotes(fileTable.Name)
for _, table := range existingSchema.Tables {
if stripQuotes(table.Name) == fileTableName {
existingTable = table
break
}
}
// If table doesn't exist, add it
if existingTable == nil {
existingSchema.Tables = append(existingSchema.Tables, fileTable)
} else {
// Merge table properties - tables are identical, skip
mergeTable(existingTable, fileTable)
}
}
// Merge other schema properties
existingSchema.Views = append(existingSchema.Views, fileSchema.Views...)
existingSchema.Sequences = append(existingSchema.Sequences, fileSchema.Sequences...)
existingSchema.Scripts = append(existingSchema.Scripts, fileSchema.Scripts...)
// Merge permissions
if existingSchema.Permissions == nil {
existingSchema.Permissions = make(map[string]string)
}
for key, val := range fileSchema.Permissions {
existingSchema.Permissions[key] = val
}
// Merge metadata
if existingSchema.Metadata == nil {
existingSchema.Metadata = make(map[string]any)
}
for key, val := range fileSchema.Metadata {
existingSchema.Metadata[key] = val
}
}
// mergeDatabase merges schemas from fileDB into baseDB
func mergeDatabase(baseDB, fileDB *models.Database) {
// Merge each schema from fileDB
for _, fileSchema := range fileDB.Schemas {
mergeSchema(baseDB, fileSchema)
}
// Merge domains
baseDB.Domains = append(baseDB.Domains, fileDB.Domains...)
// Use first non-empty description
if baseDB.Description == "" && fileDB.Description != "" {
baseDB.Description = fileDB.Description
}
}
// parseDBML parses DBML content and returns a Database model
func (r *Reader) parseDBML(content string) (*models.Database, error) {
db := models.InitDatabase("database")
@@ -332,27 +622,31 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Format: (columns) [attributes] OR columnname [attributes]
var columns []string
if strings.Contains(line, "(") && strings.Contains(line, ")") {
// Find the attributes section to avoid parsing parentheses in notes/attributes
attrStart := strings.Index(line, "[")
columnPart := line
if attrStart > 0 {
columnPart = line[:attrStart]
}
if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") {
// Multi-column format: (col1, col2) [attributes]
colStart := strings.Index(line, "(")
colEnd := strings.Index(line, ")")
colStart := strings.Index(columnPart, "(")
colEnd := strings.Index(columnPart, ")")
if colStart >= colEnd {
return nil
}
columnsStr := line[colStart+1 : colEnd]
columnsStr := columnPart[colStart+1 : colEnd]
for _, col := range strings.Split(columnsStr, ",") {
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
}
} else if strings.Contains(line, "[") {
} else if attrStart > 0 {
// Single column format: columnname [attributes]
// Extract column name before the bracket
idx := strings.Index(line, "[")
if idx > 0 {
colName := strings.TrimSpace(line[:idx])
if colName != "" {
columns = []string{stripQuotes(colName)}
}
colName := strings.TrimSpace(columnPart)
if colName != "" {
columns = []string{stripQuotes(colName)}
}
}