package dbml

import (
	"bufio"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/readers"
)

// Package-level regexes, compiled once instead of on every call.
var (
	// reBracketComment matches DBML bracket settings (e.g. [note: 'x']) for removal.
	reBracketComment = regexp.MustCompile(`\s*\[.*?\]\s*`)
	// reFilePrefix matches a numeric filename prefix like "1_" or "02-".
	reFilePrefix = regexp.MustCompile(`^(\d+)[_-]`)
	// reCommentedRef matches commented-out Ref lines like "// Ref: a.b > c.d".
	reCommentedRef = regexp.MustCompile(`^\s*//.*Ref:\s+`)
	// reTableDef matches the opening line of a Table definition.
	reTableDef = regexp.MustCompile(`^Table\s+(.+?)\s*{`)
	// reRefLine matches a standalone Ref statement.
	reRefLine = regexp.MustCompile(`^Ref:\s+(.+)`)
)

// Reader implements the readers.Reader interface for DBML format
type Reader struct {
	options *readers.ReaderOptions
}

// NewReader creates a new DBML reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
	return &Reader{
		options: options,
	}
}

// ReadDatabase reads and parses DBML input, returning a Database model
// If FilePath points to a directory, all .dbml files are loaded and merged
func (r *Reader) ReadDatabase() (*models.Database, error) {
	if r.options.FilePath == "" {
		return nil, fmt.Errorf("file path is required for DBML reader")
	}

	// Check if path is a directory
	info, err := os.Stat(r.options.FilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to stat path: %w", err)
	}

	if info.IsDir() {
		return r.readDirectoryDBML(r.options.FilePath)
	}

	// Single file - existing logic
	content, err := os.ReadFile(r.options.FilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}

	return r.parseDBML(string(content))
}

// ReadSchema reads and parses DBML input, returning a Schema model
func (r *Reader) ReadSchema() (*models.Schema, error) {
	db, err := r.ReadDatabase()
	if err != nil {
		return nil, err
	}

	if len(db.Schemas) == 0 {
		return nil, fmt.Errorf("no schemas found in DBML")
	}

	// Return the first schema
	return db.Schemas[0], nil
}

// ReadTable reads and parses DBML input, returning a Table model
func (r *Reader) ReadTable() (*models.Table, error) {
	schema, err := r.ReadSchema()
	if err != nil {
		return nil, err
	}

	if len(schema.Tables) == 0 {
		return nil, fmt.Errorf("no tables found in DBML")
	}

	// Return the first table
	return schema.Tables[0], nil
}

// readDirectoryDBML processes all .dbml files in directory
// Returns merged Database model
func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
	// Discover and sort DBML files
	files, err := r.discoverDBMLFiles(dirPath)
	if err != nil {
		return nil, fmt.Errorf("failed to discover DBML files: %w", err)
	}

	// If no files found, return empty database
	if len(files) == 0 {
		db := models.InitDatabase("database")
		if r.options.Metadata != nil {
			if name, ok := r.options.Metadata["name"].(string); ok {
				db.Name = name
			}
		}
		return db, nil
	}

	// Initialize database (will be merged with files)
	var db *models.Database

	// Process each file in sorted order
	for _, filePath := range files {
		content, err := os.ReadFile(filePath)
		if err != nil {
			return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
		}

		fileDB, err := r.parseDBML(string(content))
		if err != nil {
			return nil, fmt.Errorf("failed to parse file %s: %w", filePath, err)
		}

		// First file initializes the database
		if db == nil {
			db = fileDB
		} else {
			// Subsequent files are merged
			mergeDatabase(db, fileDB)
		}
	}

	return db, nil
}

// stripQuotes removes surrounding quotes and comments from an identifier
func stripQuotes(s string) string {
	s = strings.TrimSpace(s)

	// Remove DBML comments in brackets (e.g., [note: 'description'])
	// This handles inline comments like: "table_name" [note: 'comment']
	s = reBracketComment.ReplaceAllString(s, "")

	// Trim again after removing comments
	s = strings.TrimSpace(s)

	// Remove surrounding quotes (double or single)
	if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
		return s[1 : len(s)-1]
	}
	return s
}

// parseFilePrefix extracts numeric prefix from filename
// Examples: "1_schema.dbml" -> (1, true), "tables.dbml" -> (0, false)
func parseFilePrefix(filename string) (int, bool) {
	base := filepath.Base(filename)
	matches := reFilePrefix.FindStringSubmatch(base)
	if len(matches) > 1 {
		// strconv.Atoi is the idiomatic (and cheaper) way to parse a plain
		// decimal integer; it errors on overflow just like Sscanf would.
		if prefix, err := strconv.Atoi(matches[1]); err == nil {
			return prefix, true
		}
	}
	return 0, false
}

// hasCommentedRefs scans file content for commented-out Ref statements
// Returns true if file contains lines like: // Ref: table.col > other.col
func hasCommentedRefs(filePath string) (bool, error) {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return false, err
	}

	scanner := bufio.NewScanner(strings.NewReader(string(content)))
	for scanner.Scan() {
		if reCommentedRef.MatchString(scanner.Text()) {
			return true, nil
		}
	}
	return false, nil
}

// discoverDBMLFiles finds all .dbml files in directory and returns them sorted
func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) {
	pattern := filepath.Join(dirPath, "*.dbml")
	files, err := filepath.Glob(pattern)
	if err != nil {
		return nil, fmt.Errorf("failed to glob .dbml files: %w", err)
	}
	return sortDBMLFiles(files), nil
}

// sortDBMLFiles sorts files by:
// 1. Files without commented refs (by numeric prefix, then alphabetically)
// 2. Files with commented refs (by numeric prefix, then alphabetically)
func sortDBMLFiles(files []string) []string {
	// Create a slice to hold file info for sorting
	type fileInfo struct {
		path         string
		hasCommented bool
		prefix       int
		hasPrefix    bool
		basename     string
	}

	fileInfos := make([]fileInfo, 0, len(files))
	for _, file := range files {
		hasCommented, err := hasCommentedRefs(file)
		if err != nil {
			// If we can't read the file, treat it as not having commented refs
			hasCommented = false
		}
		prefix, hasPrefix := parseFilePrefix(file)
		fileInfos = append(fileInfos, fileInfo{
			path:         file,
			hasCommented: hasCommented,
			prefix:       prefix,
			hasPrefix:    hasPrefix,
			basename:     filepath.Base(file),
		})
	}

	// Sort by: hasCommented (false first), hasPrefix (true first), prefix, basename
	sort.Slice(fileInfos, func(i, j int) bool {
		// First, sort by commented refs (files without commented refs come first)
		if fileInfos[i].hasCommented != fileInfos[j].hasCommented {
			return !fileInfos[i].hasCommented
		}
		// Then by presence of prefix (files with prefix come first)
		if fileInfos[i].hasPrefix != fileInfos[j].hasPrefix {
			return fileInfos[i].hasPrefix
		}
		// If both have prefix, sort by prefix value
		if fileInfos[i].hasPrefix && fileInfos[j].hasPrefix {
			if fileInfos[i].prefix != fileInfos[j].prefix {
				return fileInfos[i].prefix < fileInfos[j].prefix
			}
		}
		// Finally, sort alphabetically by basename
		return fileInfos[i].basename < fileInfos[j].basename
	})

	// Extract sorted paths
	sortedFiles := make([]string, len(fileInfos))
	for i, info := range fileInfos {
		sortedFiles[i] = info.path
	}
	return sortedFiles
}

// mergeTable combines two table definitions
// Merges: Columns (map), Constraints (map), Indexes (map), Relationships (map)
// Uses first non-empty Description
func mergeTable(baseTable, fileTable *models.Table) {
	// Merge columns (map naturally merges - later keys overwrite)
	for key, col := range fileTable.Columns {
		baseTable.Columns[key] = col
	}
	// Merge constraints
	for key, constraint := range fileTable.Constraints {
		baseTable.Constraints[key] = constraint
	}
	// Merge indexes
	for key, index := range fileTable.Indexes {
		baseTable.Indexes[key] = index
	}
	// Merge relationships
	for key, rel := range fileTable.Relationships {
		baseTable.Relationships[key] = rel
	}
	// Use first non-empty description
	if baseTable.Description == "" && fileTable.Description != "" {
		baseTable.Description = fileTable.Description
	}
	// Merge metadata maps
	if baseTable.Metadata == nil {
		baseTable.Metadata = make(map[string]any)
	}
	for key, val := range fileTable.Metadata {
		baseTable.Metadata[key] = val
	}
}

// mergeSchema finds or creates schema and merges tables
func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) {
	// Find existing schema by name (normalize names by stripping quotes)
	var existingSchema *models.Schema
	fileSchemaName := stripQuotes(fileSchema.Name)
	for _, schema := range baseDB.Schemas {
		if stripQuotes(schema.Name) == fileSchemaName {
			existingSchema = schema
			break
		}
	}

	// If schema doesn't exist, add it and return
	if existingSchema == nil {
		baseDB.Schemas = append(baseDB.Schemas, fileSchema)
		return
	}

	// Merge tables from fileSchema into existingSchema
	for _, fileTable := range fileSchema.Tables {
		// Find existing table by name (normalize names by stripping quotes)
		var existingTable *models.Table
		fileTableName := stripQuotes(fileTable.Name)
		for _, table := range existingSchema.Tables {
			if stripQuotes(table.Name) == fileTableName {
				existingTable = table
				break
			}
		}

		if existingTable == nil {
			// Table doesn't exist yet: adopt it as-is
			existingSchema.Tables = append(existingSchema.Tables, fileTable)
		} else {
			// Table exists in both files: merge its members
			mergeTable(existingTable, fileTable)
		}
	}

	// Merge other schema properties
	existingSchema.Views = append(existingSchema.Views, fileSchema.Views...)
	existingSchema.Sequences = append(existingSchema.Sequences, fileSchema.Sequences...)
	existingSchema.Scripts = append(existingSchema.Scripts, fileSchema.Scripts...)

	// Merge permissions
	if existingSchema.Permissions == nil {
		existingSchema.Permissions = make(map[string]string)
	}
	for key, val := range fileSchema.Permissions {
		existingSchema.Permissions[key] = val
	}

	// Merge metadata
	if existingSchema.Metadata == nil {
		existingSchema.Metadata = make(map[string]any)
	}
	for key, val := range fileSchema.Metadata {
		existingSchema.Metadata[key] = val
	}
}

// mergeDatabase merges schemas from fileDB into baseDB
func mergeDatabase(baseDB, fileDB *models.Database) {
	// Merge each schema from fileDB
	for _, fileSchema := range fileDB.Schemas {
		mergeSchema(baseDB, fileSchema)
	}

	// Merge domains
	baseDB.Domains = append(baseDB.Domains, fileDB.Domains...)

	// Use first non-empty description
	if baseDB.Description == "" && fileDB.Description != "" {
		baseDB.Description = fileDB.Description
	}
}

// parseDBML parses DBML content and returns a Database model
func (r *Reader) parseDBML(content string) (*models.Database, error) {
	db := models.InitDatabase("database")
	if r.options.Metadata != nil {
		if name, ok := r.options.Metadata["name"].(string); ok {
			db.Name = name
		}
	}

	scanner := bufio.NewScanner(strings.NewReader(content))
	schemaMap := make(map[string]*models.Schema)
	pendingConstraints := []*models.Constraint{}

	var currentTable *models.Table
	var currentSchema string
	var inIndexes bool
	var inTable bool

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, "//") {
			continue
		}

		// Parse Table definition
		if matches := reTableDef.FindStringSubmatch(line); matches != nil {
			tableName := matches[1]
			parts := strings.Split(tableName, ".")
			if len(parts) == 2 {
				currentSchema = stripQuotes(parts[0])
				tableName = stripQuotes(parts[1])
			} else {
				currentSchema = "public"
				tableName = stripQuotes(parts[0])
			}

			// Ensure schema exists
			if _, exists := schemaMap[currentSchema]; !exists {
				schemaMap[currentSchema] = models.InitSchema(currentSchema)
			}

			currentTable = models.InitTable(tableName, currentSchema)
			inTable = true
			inIndexes = false
			continue
		}

		// End of indexes section.
		// BUGFIX: this must be checked BEFORE the end-of-table branch.
		// Previously the table-close branch matched first, so the closing
		// "}" of an Indexes block closed the whole table early, dropping
		// any columns or Note defined after the Indexes block.
		if inIndexes && line == "}" {
			inIndexes = false
			continue
		}

		// End of table definition
		if inTable && line == "}" {
			if currentTable != nil && currentSchema != "" {
				schemaMap[currentSchema].Tables = append(schemaMap[currentSchema].Tables, currentTable)
				currentTable = nil
			}
			inTable = false
			inIndexes = false
			continue
		}

		// Parse indexes section
		if inTable && (strings.HasPrefix(line, "Indexes {") || strings.HasPrefix(line, "indexes {")) {
			inIndexes = true
			continue
		}

		// Parse index definition
		if inIndexes && currentTable != nil {
			index := r.parseIndex(line, currentTable.Name, currentSchema)
			if index != nil {
				currentTable.Indexes[index.Name] = index
			}
			continue
		}

		// Parse table note
		if inTable && currentTable != nil && strings.HasPrefix(line, "Note:") {
			note := strings.TrimPrefix(line, "Note:")
			note = strings.Trim(note, " '\"")
			currentTable.Description = note
			continue
		}

		// Parse column definition
		if inTable && !inIndexes && currentTable != nil {
			column, constraint := r.parseColumn(line, currentTable.Name, currentSchema)
			if column != nil {
				currentTable.Columns[column.Name] = column
			}
			if constraint != nil {
				// Add to pending list - will assign to tables at the end
				pendingConstraints = append(pendingConstraints, constraint)
			}
			continue
		}

		// Parse Ref (relationship/foreign key)
		if matches := reRefLine.FindStringSubmatch(line); matches != nil {
			constraint := r.parseRef(matches[1])
			if constraint != nil {
				// Find the table and add the constraint
				for _, schema := range schemaMap {
					for _, table := range schema.Tables {
						if table.Schema == constraint.Schema && table.Name == constraint.Table {
							table.Constraints[constraint.Name] = constraint
							break
						}
					}
				}
			}
			continue
		}
	}

	// Assign pending constraints to their respective tables
	for _, constraint := range pendingConstraints {
		// Find the table this constraint belongs to
		if schema, exists := schemaMap[constraint.Schema]; exists {
			for _, table := range schema.Tables {
				if table.Name == constraint.Table {
					table.Constraints[constraint.Name] = constraint
					break
				}
			}
		}
	}

	// Add schemas to database
	for _, schema := range schemaMap {
		db.Schemas = append(db.Schemas, schema)
	}

	return db, nil
}

// parseColumn parses a DBML column definition
func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) {
	// Format: column_name type [attributes] // comment
	parts := strings.Fields(line)
	if len(parts) < 2 {
		return nil, nil
	}

	columnName := stripQuotes(parts[0])
	columnType := stripQuotes(parts[1])

	column := models.InitColumn(columnName, tableName, schemaName)
	column.Type = columnType

	var constraint *models.Constraint

	// Parse attributes in brackets
	if strings.Contains(line, "[") && strings.Contains(line, "]") {
		attrStart := strings.Index(line, "[")
		attrEnd := strings.Index(line, "]")
		if attrStart < attrEnd {
			attrs := line[attrStart+1 : attrEnd]
			attrList := strings.Split(attrs, ",")
			for _, attr := range attrList {
				attr = strings.TrimSpace(attr)
				if strings.Contains(attr, "primary key") || attr == "pk" {
					column.IsPrimaryKey = true
					column.NotNull = true
				} else if strings.Contains(attr, "not null") {
					column.NotNull = true
				} else if attr == "increment" {
					column.AutoIncrement = true
				} else if strings.HasPrefix(attr, "default:") {
					defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
					column.Default = strings.Trim(defaultVal, "'\"")
				} else if attr == "unique" {
					// Create a unique constraint
					uniqueConstraint := models.InitConstraint(
						fmt.Sprintf("uq_%s", columnName),
						models.UniqueConstraint,
					)
					uniqueConstraint.Schema = schemaName
					uniqueConstraint.Table = tableName
					uniqueConstraint.Columns = []string{columnName}
					// Store it to be added later
					if constraint == nil {
						constraint = uniqueConstraint
					}
				} else if strings.HasPrefix(attr, "note:") {
					// Parse column note/comment
					note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
					column.Comment = strings.Trim(note, "'\"")
				} else if strings.HasPrefix(attr, "ref:") {
					// Parse inline reference
					// DBML semantics depend on context:
					//   - On FK column: ref: < target means "this FK references target"
					//   - On PK column: ref: < source means "source references this PK" (reverse notation)
					refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))

					// Check relationship direction operator
					refOp := strings.TrimSpace(refStr)
					var isReverse bool
					if strings.HasPrefix(refOp, "<") {
						isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse)
					} else if strings.HasPrefix(refOp, ">") {
						isReverse = !column.IsPrimaryKey // > on FK means reverse
					}

					constraint = r.parseRef(refStr)
					if constraint != nil {
						if isReverse {
							// Reverse: parsed ref is SOURCE, current column is TARGET
							// Constraint should be ON the source table
							constraint.Schema = constraint.ReferencedSchema
							constraint.Table = constraint.ReferencedTable
							constraint.Columns = constraint.ReferencedColumns
							constraint.ReferencedSchema = schemaName
							constraint.ReferencedTable = tableName
							constraint.ReferencedColumns = []string{columnName}
						} else {
							// Forward: current column is SOURCE, parsed ref is TARGET
							// Standard FK: constraint is ON current table
							constraint.Schema = schemaName
							constraint.Table = tableName
							constraint.Columns = []string{columnName}
						}
						// Generate short constraint name based on the column
						constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
					}
				}
			}
		}
	}

	// Parse inline comment
	if strings.Contains(line, "//") {
		commentStart := strings.Index(line, "//")
		column.Comment = strings.TrimSpace(line[commentStart+2:])
	}

	return column, constraint
}

// parseIndex parses a DBML index definition
func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
	// Format: (columns) [attributes] OR columnname [attributes]
	var columns []string

	// Find the attributes section to avoid parsing parentheses in notes/attributes
	attrStart := strings.Index(line, "[")
	columnPart := line
	if attrStart > 0 {
		columnPart = line[:attrStart]
	}

	if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") {
		// Multi-column format: (col1, col2) [attributes]
		colStart := strings.Index(columnPart, "(")
		colEnd := strings.Index(columnPart, ")")
		if colStart >= colEnd {
			return nil
		}
		columnsStr := columnPart[colStart+1 : colEnd]
		for _, col := range strings.Split(columnsStr, ",") {
			columns = append(columns, stripQuotes(strings.TrimSpace(col)))
		}
	} else if attrStart > 0 {
		// Single column format: columnname [attributes]
		// Extract column name before the bracket
		colName := strings.TrimSpace(columnPart)
		if colName != "" {
			columns = []string{stripQuotes(colName)}
		}
	}

	if len(columns) == 0 {
		return nil
	}

	index := models.InitIndex("", tableName, schemaName)
	index.Table = tableName
	index.Schema = schemaName
	index.Columns = columns

	// Parse attributes
	if strings.Contains(line, "[") && strings.Contains(line, "]") {
		attrStart := strings.Index(line, "[")
		attrEnd := strings.Index(line, "]")
		if attrStart < attrEnd {
			attrs := line[attrStart+1 : attrEnd]
			attrList := strings.Split(attrs, ",")
			for _, attr := range attrList {
				attr = strings.TrimSpace(attr)
				if attr == "unique" {
					index.Unique = true
				} else if strings.HasPrefix(attr, "name:") {
					name := strings.TrimSpace(strings.TrimPrefix(attr, "name:"))
					index.Name = strings.Trim(name, "'\"")
				} else if strings.HasPrefix(attr, "type:") {
					indexType := strings.TrimSpace(strings.TrimPrefix(attr, "type:"))
					index.Type = strings.Trim(indexType, "'\"")
				}
			}
		}
	}

	// Generate name if not provided
	if index.Name == "" {
		index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
	}

	return index
}

// parseRef parses a DBML Ref (foreign key relationship)
func (r *Reader) parseRef(refStr string) *models.Constraint {
	// Format: schema.table.(columns) > schema.table.(columns) [actions]
	// Or inline format: < schema.table.column (for inline column refs)

	// Split by relationship operator (>, <, -, etc.)
	var fromPart, toPart string
	isInlineRef := false

	for _, op := range []string{">", "<", "-"} {
		if strings.Contains(refStr, op) {
			parts := strings.Split(refStr, op)
			if len(parts) == 2 {
				fromPart = strings.TrimSpace(parts[0])
				toPart = strings.TrimSpace(parts[1])
				// Check if this is an inline ref (operator at start)
				if fromPart == "" {
					isInlineRef = true
				}
				break
			}
		}
	}

	// For inline refs, only toPart should be populated
	if isInlineRef {
		if toPart == "" {
			return nil
		}
	} else if fromPart == "" || toPart == "" {
		return nil
	}

	// Remove actions part if present
	if idx := strings.Index(toPart, "["); idx >= 0 {
		toPart = strings.TrimSpace(toPart[:idx])
	}

	// Parse references
	var fromSchema, fromTable string
	var fromColumns []string
	toSchema, toTable, toColumns := r.parseTableRef(toPart)

	if !isInlineRef {
		fromSchema, fromTable, fromColumns = r.parseTableRef(fromPart)
		if fromTable == "" {
			return nil
		}
	}

	if toTable == "" {
		return nil
	}

	// Generate short constraint name based on the source column
	constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
	if len(fromColumns) > 0 {
		constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
	}

	constraint := models.InitConstraint(
		constraintName,
		models.ForeignKeyConstraint,
	)
	constraint.Schema = fromSchema
	constraint.Table = fromTable
	constraint.Columns = fromColumns
	constraint.ReferencedSchema = toSchema
	constraint.ReferencedTable = toTable
	constraint.ReferencedColumns = toColumns

	// Parse actions if present
	if strings.Contains(refStr, "[") && strings.Contains(refStr, "]") {
		actStart := strings.Index(refStr, "[")
		actEnd := strings.Index(refStr, "]")
		if actStart < actEnd {
			actions := refStr[actStart+1 : actEnd]
			actionList := strings.Split(actions, ",")
			for _, action := range actionList {
				action = strings.TrimSpace(action)
				if strings.HasPrefix(action, "ondelete:") {
					constraint.OnDelete = strings.TrimSpace(strings.TrimPrefix(action, "ondelete:"))
				} else if strings.HasPrefix(action, "onupdate:") {
					constraint.OnUpdate = strings.TrimSpace(strings.TrimPrefix(action, "onupdate:"))
				}
			}
		}
	}

	return constraint
}

// parseTableRef parses a table reference like "schema.table.(column1, column2)" or "schema"."table"."column"
func (r *Reader) parseTableRef(ref string) (schema, table string, columns []string) {
	// Extract columns if present in parentheses format
	hasParentheses := false
	if strings.Contains(ref, "(") && strings.Contains(ref, ")") {
		colStart := strings.Index(ref, "(")
		colEnd := strings.Index(ref, ")")
		if colStart < colEnd {
			columnsStr := ref[colStart+1 : colEnd]
			for _, col := range strings.Split(columnsStr, ",") {
				columns = append(columns, stripQuotes(strings.TrimSpace(col)))
			}
			hasParentheses = true
		}
		ref = ref[:colStart]
	}

	// Parse schema, table, and optionally column
	parts := strings.Split(strings.TrimSpace(ref), ".")
	if len(parts) == 3 {
		// Format: "schema"."table"."column"
		schema = stripQuotes(parts[0])
		table = stripQuotes(parts[1])
		if !hasParentheses {
			columns = []string{stripQuotes(parts[2])}
		}
	} else if len(parts) == 2 {
		// Could be "schema"."table" or "table"."column"
		// If columns are already extracted from parentheses, this is schema.table
		// If no parentheses, this is table.column
		if hasParentheses {
			schema = stripQuotes(parts[0])
			table = stripQuotes(parts[1])
		} else {
			schema = "public"
			table = stripQuotes(parts[0])
			columns = []string{stripQuotes(parts[1])}
		}
	} else if len(parts) == 1 {
		// Format: "table"
		schema = "public"
		table = stripQuotes(parts[0])
	}

	return
}