Files
relspecgo/pkg/readers/dbml/reader.go
Hein 5e1448dcdb
Some checks are pending
CI / Test (1.23) (push) Waiting to run
CI / Test (1.24) (push) Waiting to run
CI / Test (1.25) (push) Waiting to run
CI / Lint (push) Waiting to run
CI / Build (push) Waiting to run
sql writer
2025-12-17 20:44:02 +02:00

536 lines
15 KiB
Go

package dbml
import (
"bufio"
"fmt"
"os"
"regexp"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// Reader implements the readers.Reader interface for DBML format
type Reader struct {
options *readers.ReaderOptions
}
// NewReader creates a new DBML reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
return &Reader{
options: options,
}
}
// ReadDatabase reads and parses DBML input, returning a Database model
func (r *Reader) ReadDatabase() (*models.Database, error) {
if r.options.FilePath == "" {
return nil, fmt.Errorf("file path is required for DBML reader")
}
content, err := os.ReadFile(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}
return r.parseDBML(string(content))
}
// ReadSchema reads and parses DBML input, returning a Schema model
func (r *Reader) ReadSchema() (*models.Schema, error) {
db, err := r.ReadDatabase()
if err != nil {
return nil, err
}
if len(db.Schemas) == 0 {
return nil, fmt.Errorf("no schemas found in DBML")
}
// Return the first schema
return db.Schemas[0], nil
}
// ReadTable reads and parses DBML input, returning a Table model
func (r *Reader) ReadTable() (*models.Table, error) {
schema, err := r.ReadSchema()
if err != nil {
return nil, err
}
if len(schema.Tables) == 0 {
return nil, fmt.Errorf("no tables found in DBML")
}
// Return the first table
return schema.Tables[0], nil
}
// stripQuotes removes surrounding quotes from an identifier
func stripQuotes(s string) string {
s = strings.TrimSpace(s)
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
return s[1 : len(s)-1]
}
return s
}
// parseDBML parses DBML content and returns a Database model
func (r *Reader) parseDBML(content string) (*models.Database, error) {
db := models.InitDatabase("database")
if r.options.Metadata != nil {
if name, ok := r.options.Metadata["name"].(string); ok {
db.Name = name
}
}
scanner := bufio.NewScanner(strings.NewReader(content))
schemaMap := make(map[string]*models.Schema)
pendingConstraints := []*models.Constraint{}
var currentTable *models.Table
var currentSchema string
var inIndexes bool
var inTable bool
tableRegex := regexp.MustCompile(`^Table\s+(.+?)\s*{`)
refRegex := regexp.MustCompile(`^Ref:\s+(.+)`)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "//") {
continue
}
// Parse Table definition
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
tableName := matches[1]
parts := strings.Split(tableName, ".")
if len(parts) == 2 {
currentSchema = stripQuotes(parts[0])
tableName = stripQuotes(parts[1])
} else {
currentSchema = "public"
tableName = stripQuotes(parts[0])
}
// Ensure schema exists
if _, exists := schemaMap[currentSchema]; !exists {
schemaMap[currentSchema] = models.InitSchema(currentSchema)
}
currentTable = models.InitTable(tableName, currentSchema)
inTable = true
inIndexes = false
continue
}
// End of table definition
if inTable && line == "}" {
if currentTable != nil && currentSchema != "" {
schemaMap[currentSchema].Tables = append(schemaMap[currentSchema].Tables, currentTable)
currentTable = nil
}
inTable = false
inIndexes = false
continue
}
// Parse indexes section
if inTable && (strings.HasPrefix(line, "Indexes {") || strings.HasPrefix(line, "indexes {")) {
inIndexes = true
continue
}
// End of indexes section
if inIndexes && line == "}" {
inIndexes = false
continue
}
// Parse index definition
if inIndexes && currentTable != nil {
index := r.parseIndex(line, currentTable.Name, currentSchema)
if index != nil {
currentTable.Indexes[index.Name] = index
}
continue
}
// Parse table note
if inTable && currentTable != nil && strings.HasPrefix(line, "Note:") {
note := strings.TrimPrefix(line, "Note:")
note = strings.Trim(note, " '\"")
currentTable.Description = note
continue
}
// Parse column definition
if inTable && !inIndexes && currentTable != nil {
column, constraint := r.parseColumn(line, currentTable.Name, currentSchema)
if column != nil {
currentTable.Columns[column.Name] = column
}
if constraint != nil {
// Add to pending list - will assign to tables at the end
pendingConstraints = append(pendingConstraints, constraint)
}
continue
}
// Parse Ref (relationship/foreign key)
if matches := refRegex.FindStringSubmatch(line); matches != nil {
constraint := r.parseRef(matches[1])
if constraint != nil {
// Find the table and add the constraint
for _, schema := range schemaMap {
for _, table := range schema.Tables {
if table.Schema == constraint.Schema && table.Name == constraint.Table {
table.Constraints[constraint.Name] = constraint
break
}
}
}
}
continue
}
}
// Assign pending constraints to their respective tables
for _, constraint := range pendingConstraints {
// Find the table this constraint belongs to
if schema, exists := schemaMap[constraint.Schema]; exists {
for _, table := range schema.Tables {
if table.Name == constraint.Table {
table.Constraints[constraint.Name] = constraint
break
}
}
}
}
// Add schemas to database
for _, schema := range schemaMap {
db.Schemas = append(db.Schemas, schema)
}
return db, nil
}
// parseColumn parses a DBML column definition
func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) {
// Format: column_name type [attributes] // comment
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, nil
}
columnName := stripQuotes(parts[0])
columnType := stripQuotes(parts[1])
column := models.InitColumn(columnName, tableName, schemaName)
column.Type = columnType
var constraint *models.Constraint
// Parse attributes in brackets
if strings.Contains(line, "[") && strings.Contains(line, "]") {
attrStart := strings.Index(line, "[")
attrEnd := strings.Index(line, "]")
if attrStart < attrEnd {
attrs := line[attrStart+1 : attrEnd]
attrList := strings.Split(attrs, ",")
for _, attr := range attrList {
attr = strings.TrimSpace(attr)
if strings.Contains(attr, "primary key") || attr == "pk" {
column.IsPrimaryKey = true
column.NotNull = true
} else if strings.Contains(attr, "not null") {
column.NotNull = true
} else if attr == "increment" {
column.AutoIncrement = true
} else if strings.HasPrefix(attr, "default:") {
defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:"))
column.Default = strings.Trim(defaultVal, "'\"")
} else if attr == "unique" {
// Create a unique constraint
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("uq_%s", columnName),
models.UniqueConstraint,
)
uniqueConstraint.Schema = schemaName
uniqueConstraint.Table = tableName
uniqueConstraint.Columns = []string{columnName}
// Store it to be added later
if constraint == nil {
constraint = uniqueConstraint
}
} else if strings.HasPrefix(attr, "ref:") {
// Parse inline reference
// DBML semantics depend on context:
// - On FK column: ref: < target means "this FK references target"
// - On PK column: ref: < source means "source references this PK" (reverse notation)
refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:"))
// Check relationship direction operator
refOp := strings.TrimSpace(refStr)
var isReverse bool
if strings.HasPrefix(refOp, "<") {
isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse)
} else if strings.HasPrefix(refOp, ">") {
isReverse = !column.IsPrimaryKey // > on FK means reverse
}
constraint = r.parseRef(refStr)
if constraint != nil {
if isReverse {
// Reverse: parsed ref is SOURCE, current column is TARGET
// Constraint should be ON the source table
constraint.Schema = constraint.ReferencedSchema
constraint.Table = constraint.ReferencedTable
constraint.Columns = constraint.ReferencedColumns
constraint.ReferencedSchema = schemaName
constraint.ReferencedTable = tableName
constraint.ReferencedColumns = []string{columnName}
} else {
// Forward: current column is SOURCE, parsed ref is TARGET
// Standard FK: constraint is ON current table
constraint.Schema = schemaName
constraint.Table = tableName
constraint.Columns = []string{columnName}
}
// Generate short constraint name based on the column
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
}
}
}
}
}
// Parse inline comment
if strings.Contains(line, "//") {
commentStart := strings.Index(line, "//")
column.Comment = strings.TrimSpace(line[commentStart+2:])
}
return column, constraint
}
// parseIndex parses a DBML index definition
func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Format: (columns) [attributes] OR columnname [attributes]
var columns []string
if strings.Contains(line, "(") && strings.Contains(line, ")") {
// Multi-column format: (col1, col2) [attributes]
colStart := strings.Index(line, "(")
colEnd := strings.Index(line, ")")
if colStart >= colEnd {
return nil
}
columnsStr := line[colStart+1 : colEnd]
for _, col := range strings.Split(columnsStr, ",") {
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
}
} else {
// Single column format: columnname [attributes]
// Extract column name before the bracket
if strings.Contains(line, "[") {
colName := strings.TrimSpace(line[:strings.Index(line, "[")])
if colName != "" {
columns = []string{stripQuotes(colName)}
}
}
}
if len(columns) == 0 {
return nil
}
index := models.InitIndex("")
index.Table = tableName
index.Schema = schemaName
index.Columns = columns
// Parse attributes
if strings.Contains(line, "[") && strings.Contains(line, "]") {
attrStart := strings.Index(line, "[")
attrEnd := strings.Index(line, "]")
if attrStart < attrEnd {
attrs := line[attrStart+1 : attrEnd]
attrList := strings.Split(attrs, ",")
for _, attr := range attrList {
attr = strings.TrimSpace(attr)
if attr == "unique" {
index.Unique = true
} else if strings.HasPrefix(attr, "name:") {
name := strings.TrimSpace(strings.TrimPrefix(attr, "name:"))
index.Name = strings.Trim(name, "'\"")
} else if strings.HasPrefix(attr, "type:") {
indexType := strings.TrimSpace(strings.TrimPrefix(attr, "type:"))
index.Type = strings.Trim(indexType, "'\"")
}
}
}
}
// Generate name if not provided
if index.Name == "" {
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
}
return index
}
// parseRef parses a DBML Ref (foreign key relationship)
func (r *Reader) parseRef(refStr string) *models.Constraint {
// Format: schema.table.(columns) > schema.table.(columns) [actions]
// Or inline format: < schema.table.column (for inline column refs)
// Split by relationship operator (>, <, -, etc.)
var fromPart, toPart string
isInlineRef := false
for _, op := range []string{">", "<", "-"} {
if strings.Contains(refStr, op) {
parts := strings.Split(refStr, op)
if len(parts) == 2 {
fromPart = strings.TrimSpace(parts[0])
toPart = strings.TrimSpace(parts[1])
// Check if this is an inline ref (operator at start)
if fromPart == "" {
isInlineRef = true
}
break
}
}
}
// For inline refs, only toPart should be populated
if isInlineRef {
if toPart == "" {
return nil
}
} else if fromPart == "" || toPart == "" {
return nil
}
// Remove actions part if present
if idx := strings.Index(toPart, "["); idx >= 0 {
toPart = strings.TrimSpace(toPart[:idx])
}
// Parse references
var fromSchema, fromTable string
var fromColumns []string
toSchema, toTable, toColumns := r.parseTableRef(toPart)
if !isInlineRef {
fromSchema, fromTable, fromColumns = r.parseTableRef(fromPart)
if fromTable == "" {
return nil
}
}
if toTable == "" {
return nil
}
// Generate short constraint name based on the source column
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
if len(fromColumns) > 0 {
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
}
constraint := models.InitConstraint(
constraintName,
models.ForeignKeyConstraint,
)
constraint.Schema = fromSchema
constraint.Table = fromTable
constraint.Columns = fromColumns
constraint.ReferencedSchema = toSchema
constraint.ReferencedTable = toTable
constraint.ReferencedColumns = toColumns
// Parse actions if present
if strings.Contains(refStr, "[") && strings.Contains(refStr, "]") {
actStart := strings.Index(refStr, "[")
actEnd := strings.Index(refStr, "]")
if actStart < actEnd {
actions := refStr[actStart+1 : actEnd]
actionList := strings.Split(actions, ",")
for _, action := range actionList {
action = strings.TrimSpace(action)
if strings.HasPrefix(action, "ondelete:") {
constraint.OnDelete = strings.TrimSpace(strings.TrimPrefix(action, "ondelete:"))
} else if strings.HasPrefix(action, "onupdate:") {
constraint.OnUpdate = strings.TrimSpace(strings.TrimPrefix(action, "onupdate:"))
}
}
}
}
return constraint
}
// parseTableRef parses a table reference like "schema.table.(column1, column2)" or "schema"."table"."column"
func (r *Reader) parseTableRef(ref string) (schema, table string, columns []string) {
// Extract columns if present in parentheses format
hasParentheses := false
if strings.Contains(ref, "(") && strings.Contains(ref, ")") {
colStart := strings.Index(ref, "(")
colEnd := strings.Index(ref, ")")
if colStart < colEnd {
columnsStr := ref[colStart+1 : colEnd]
for _, col := range strings.Split(columnsStr, ",") {
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
}
hasParentheses = true
}
ref = ref[:colStart]
}
// Parse schema, table, and optionally column
parts := strings.Split(strings.TrimSpace(ref), ".")
if len(parts) == 3 {
// Format: "schema"."table"."column"
schema = stripQuotes(parts[0])
table = stripQuotes(parts[1])
if !hasParentheses {
columns = []string{stripQuotes(parts[2])}
}
} else if len(parts) == 2 {
// Could be "schema"."table" or "table"."column"
// If columns are already extracted from parentheses, this is schema.table
// If no parentheses, this is table.column
if hasParentheses {
schema = stripQuotes(parts[0])
table = stripQuotes(parts[1])
} else {
schema = "public"
table = stripQuotes(parts[0])
columns = []string{stripQuotes(parts[1])}
}
} else if len(parts) == 1 {
// Format: "table"
schema = "public"
table = stripQuotes(parts[0])
}
return
}