package dbml import ( "bufio" "fmt" "os" "regexp" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/readers" ) // Reader implements the readers.Reader interface for DBML format type Reader struct { options *readers.ReaderOptions } // NewReader creates a new DBML reader with the given options func NewReader(options *readers.ReaderOptions) *Reader { return &Reader{ options: options, } } // ReadDatabase reads and parses DBML input, returning a Database model func (r *Reader) ReadDatabase() (*models.Database, error) { if r.options.FilePath == "" { return nil, fmt.Errorf("file path is required for DBML reader") } content, err := os.ReadFile(r.options.FilePath) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } return r.parseDBML(string(content)) } // ReadSchema reads and parses DBML input, returning a Schema model func (r *Reader) ReadSchema() (*models.Schema, error) { db, err := r.ReadDatabase() if err != nil { return nil, err } if len(db.Schemas) == 0 { return nil, fmt.Errorf("no schemas found in DBML") } // Return the first schema return db.Schemas[0], nil } // ReadTable reads and parses DBML input, returning a Table model func (r *Reader) ReadTable() (*models.Table, error) { schema, err := r.ReadSchema() if err != nil { return nil, err } if len(schema.Tables) == 0 { return nil, fmt.Errorf("no tables found in DBML") } // Return the first table return schema.Tables[0], nil } // stripQuotes removes surrounding quotes from an identifier func stripQuotes(s string) string { s = strings.TrimSpace(s) if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) { return s[1 : len(s)-1] } return s } // parseDBML parses DBML content and returns a Database model func (r *Reader) parseDBML(content string) (*models.Database, error) { db := models.InitDatabase("database") if r.options.Metadata != nil { if name, ok := r.options.Metadata["name"].(string); ok { db.Name = name } } scanner := bufio.NewScanner(strings.NewReader(content)) schemaMap := make(map[string]*models.Schema) pendingConstraints := []*models.Constraint{} var currentTable *models.Table var currentSchema string var inIndexes bool var inTable bool tableRegex := regexp.MustCompile(`^Table\s+(.+?)\s*{`) refRegex := regexp.MustCompile(`^Ref:\s+(.+)`) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) // Skip empty lines and comments if line == "" || strings.HasPrefix(line, "//") { continue } // Parse Table definition if matches := tableRegex.FindStringSubmatch(line); matches != nil { tableName := matches[1] parts := strings.Split(tableName, ".") if len(parts) == 2 { currentSchema = stripQuotes(parts[0]) tableName = stripQuotes(parts[1]) } else { currentSchema = "public" tableName = stripQuotes(parts[0]) } // Ensure schema exists if _, exists := schemaMap[currentSchema]; !exists { schemaMap[currentSchema] = models.InitSchema(currentSchema) } currentTable = models.InitTable(tableName, currentSchema) inTable = true inIndexes = false continue } // End of table definition if inTable && line == "}" { if currentTable != nil && currentSchema != "" { schemaMap[currentSchema].Tables = append(schemaMap[currentSchema].Tables, currentTable) currentTable = nil } inTable = false inIndexes = false continue } // Parse indexes section if inTable && (strings.HasPrefix(line, "Indexes {") || strings.HasPrefix(line, "indexes {")) { inIndexes = true continue } // End of indexes section if inIndexes && line == "}" { inIndexes = false continue } // Parse index definition if inIndexes && currentTable != nil { index := r.parseIndex(line, currentTable.Name, currentSchema) if index != nil { currentTable.Indexes[index.Name] = index } continue } // Parse table note if inTable && currentTable != nil && strings.HasPrefix(line, "Note:") { note := strings.TrimPrefix(line, "Note:") note = strings.Trim(note, " '\"") currentTable.Description = note continue } // Parse column definition if inTable && !inIndexes && currentTable != nil { column, constraint := r.parseColumn(line, currentTable.Name, currentSchema) if column != nil { currentTable.Columns[column.Name] = column } if constraint != nil { // Add to pending list - will assign to tables at the end pendingConstraints = append(pendingConstraints, constraint) } continue } // Parse Ref (relationship/foreign key) if matches := refRegex.FindStringSubmatch(line); matches != nil { constraint := r.parseRef(matches[1]) if constraint != nil { // Find the table and add the constraint for _, schema := range schemaMap { for _, table := range schema.Tables { if table.Schema == constraint.Schema && table.Name == constraint.Table { table.Constraints[constraint.Name] = constraint break } } } } continue } } // Assign pending constraints to their respective tables for _, constraint := range pendingConstraints { // Find the table this constraint belongs to if schema, exists := schemaMap[constraint.Schema]; exists { for _, table := range schema.Tables { if table.Name == constraint.Table { table.Constraints[constraint.Name] = constraint break } } } } // Add schemas to database for _, schema := range schemaMap { db.Schemas = append(db.Schemas, schema) } return db, nil } // parseColumn parses a DBML column definition func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) { // Format: column_name type [attributes] // comment parts := strings.Fields(line) if len(parts) < 2 { return nil, nil } columnName := stripQuotes(parts[0]) columnType := stripQuotes(parts[1]) column := models.InitColumn(columnName, tableName, schemaName) column.Type = columnType var constraint *models.Constraint // Parse attributes in brackets if strings.Contains(line, "[") && strings.Contains(line, "]") { attrStart := strings.Index(line, "[") attrEnd := strings.Index(line, "]") if attrStart < attrEnd { attrs := line[attrStart+1 : attrEnd] attrList := strings.Split(attrs, ",") for _, attr := range attrList { attr = strings.TrimSpace(attr) if strings.Contains(attr, "primary key") || attr == "pk" { column.IsPrimaryKey = true column.NotNull = true } else if strings.Contains(attr, "not null") { column.NotNull = true } else if attr == "increment" { column.AutoIncrement = true } else if strings.HasPrefix(attr, "default:") { defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:")) column.Default = strings.Trim(defaultVal, "'\"") } else if attr == "unique" { // Create a unique constraint uniqueConstraint := models.InitConstraint( fmt.Sprintf("uq_%s", columnName), models.UniqueConstraint, ) uniqueConstraint.Schema = schemaName uniqueConstraint.Table = tableName uniqueConstraint.Columns = []string{columnName} // Store it to be added later if constraint == nil { constraint = uniqueConstraint } } else if strings.HasPrefix(attr, "note:") { // Parse column note/comment note := strings.TrimSpace(strings.TrimPrefix(attr, "note:")) column.Comment = strings.Trim(note, "'\"") } else if strings.HasPrefix(attr, "ref:") { // Parse inline reference // DBML semantics depend on context: // - On FK column: ref: < target means "this FK references target" // - On PK column: ref: < source means "source references this PK" (reverse notation) refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:")) // Check relationship direction operator refOp := strings.TrimSpace(refStr) var isReverse bool if strings.HasPrefix(refOp, "<") { isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse) } else if strings.HasPrefix(refOp, ">") { isReverse = !column.IsPrimaryKey // > on FK means reverse } constraint = r.parseRef(refStr) if constraint != nil { if isReverse { // Reverse: parsed ref is SOURCE, current column is TARGET // Constraint should be ON the source table constraint.Schema = constraint.ReferencedSchema constraint.Table = constraint.ReferencedTable constraint.Columns = constraint.ReferencedColumns constraint.ReferencedSchema = schemaName constraint.ReferencedTable = tableName constraint.ReferencedColumns = []string{columnName} } else { // Forward: current column is SOURCE, parsed ref is TARGET // Standard FK: constraint is ON current table constraint.Schema = schemaName constraint.Table = tableName constraint.Columns = []string{columnName} } // Generate short constraint name based on the column constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0]) } } } } } // Parse inline comment if strings.Contains(line, "//") { commentStart := strings.Index(line, "//") column.Comment = strings.TrimSpace(line[commentStart+2:]) } return column, constraint } // parseIndex parses a DBML index definition func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index { // Format: (columns) [attributes] OR columnname [attributes] var columns []string if strings.Contains(line, "(") && strings.Contains(line, ")") { // Multi-column format: (col1, col2) [attributes] colStart := strings.Index(line, "(") colEnd := strings.Index(line, ")") if colStart >= colEnd { return nil } columnsStr := line[colStart+1 : colEnd] for _, col := range strings.Split(columnsStr, ",") { columns = append(columns, stripQuotes(strings.TrimSpace(col))) } } else if strings.Contains(line, "[") { // Single column format: columnname [attributes] // Extract column name before the bracket idx := strings.Index(line, "[") if idx > 0 { colName := strings.TrimSpace(line[:idx]) if colName != "" { columns = []string{stripQuotes(colName)} } } } if len(columns) == 0 { return nil } index := models.InitIndex("", tableName, schemaName) index.Table = tableName index.Schema = schemaName index.Columns = columns // Parse attributes if strings.Contains(line, "[") && strings.Contains(line, "]") { attrStart := strings.Index(line, "[") attrEnd := strings.Index(line, "]") if attrStart < attrEnd { attrs := line[attrStart+1 : attrEnd] attrList := strings.Split(attrs, ",") for _, attr := range attrList { attr = strings.TrimSpace(attr) if attr == "unique" { index.Unique = true } else if strings.HasPrefix(attr, "name:") { name := strings.TrimSpace(strings.TrimPrefix(attr, "name:")) index.Name = strings.Trim(name, "'\"") } else if strings.HasPrefix(attr, "type:") { indexType := strings.TrimSpace(strings.TrimPrefix(attr, "type:")) index.Type = strings.Trim(indexType, "'\"") } } } } // Generate name if not provided if index.Name == "" { index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_")) } return index } // parseRef parses a DBML Ref (foreign key relationship) func (r *Reader) parseRef(refStr string) *models.Constraint { // Format: schema.table.(columns) > schema.table.(columns) [actions] // Or inline format: < schema.table.column (for inline column refs) // Split by relationship operator (>, <, -, etc.) var fromPart, toPart string isInlineRef := false for _, op := range []string{">", "<", "-"} { if strings.Contains(refStr, op) { parts := strings.Split(refStr, op) if len(parts) == 2 { fromPart = strings.TrimSpace(parts[0]) toPart = strings.TrimSpace(parts[1]) // Check if this is an inline ref (operator at start) if fromPart == "" { isInlineRef = true } break } } } // For inline refs, only toPart should be populated if isInlineRef { if toPart == "" { return nil } } else if fromPart == "" || toPart == "" { return nil } // Remove actions part if present if idx := strings.Index(toPart, "["); idx >= 0 { toPart = strings.TrimSpace(toPart[:idx]) } // Parse references var fromSchema, fromTable string var fromColumns []string toSchema, toTable, toColumns := r.parseTableRef(toPart) if !isInlineRef { fromSchema, fromTable, fromColumns = r.parseTableRef(fromPart) if fromTable == "" { return nil } } if toTable == "" { return nil } // Generate short constraint name based on the source column constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable) if len(fromColumns) > 0 { constraintName = fmt.Sprintf("fk_%s", fromColumns[0]) } constraint := models.InitConstraint( constraintName, models.ForeignKeyConstraint, ) constraint.Schema = fromSchema constraint.Table = fromTable constraint.Columns = fromColumns constraint.ReferencedSchema = toSchema constraint.ReferencedTable = toTable constraint.ReferencedColumns = toColumns // Parse actions if present if strings.Contains(refStr, "[") && strings.Contains(refStr, "]") { actStart := strings.Index(refStr, "[") actEnd := strings.Index(refStr, "]") if actStart < actEnd { actions := refStr[actStart+1 : actEnd] actionList := strings.Split(actions, ",") for _, action := range actionList { action = strings.TrimSpace(action) if strings.HasPrefix(action, "ondelete:") { constraint.OnDelete = strings.TrimSpace(strings.TrimPrefix(action, "ondelete:")) } else if strings.HasPrefix(action, "onupdate:") { constraint.OnUpdate = strings.TrimSpace(strings.TrimPrefix(action, "onupdate:")) } } } } return constraint } // parseTableRef parses a table reference like "schema.table.(column1, column2)" or "schema"."table"."column" func (r *Reader) parseTableRef(ref string) (schema, table string, columns []string) { // Extract columns if present in parentheses format hasParentheses := false if strings.Contains(ref, "(") && strings.Contains(ref, ")") { colStart := strings.Index(ref, "(") colEnd := strings.Index(ref, ")") if colStart < colEnd { columnsStr := ref[colStart+1 : colEnd] for _, col := range strings.Split(columnsStr, ",") { columns = append(columns, stripQuotes(strings.TrimSpace(col))) } hasParentheses = true } ref = ref[:colStart] } // Parse schema, table, and optionally column parts := strings.Split(strings.TrimSpace(ref), ".") if len(parts) == 3 { // Format: "schema"."table"."column" schema = stripQuotes(parts[0]) table = stripQuotes(parts[1]) if !hasParentheses { columns = []string{stripQuotes(parts[2])} } } else if len(parts) == 2 { // Could be "schema"."table" or "table"."column" // If columns are already extracted from parentheses, this is schema.table // If no parentheses, this is table.column if hasParentheses { schema = stripQuotes(parts[0]) table = stripQuotes(parts[1]) } else { schema = "public" table = stripQuotes(parts[0]) columns = []string{stripQuotes(parts[1])} } } else if len(parts) == 1 { // Format: "table" schema = "public" table = stripQuotes(parts[0]) } return }