Files
relspecgo/pkg/readers/drizzle/reader.go
Hein b4ff4334cc
Some checks failed
CI / Lint (push) Successful in -27m53s
CI / Test (1.24) (push) Successful in -27m31s
CI / Build (push) Successful in -28m13s
CI / Test (1.25) (push) Failing after 1m11s
Integration Tests / Integration Tests (push) Failing after -28m15s
feat(models): 🎉 Add GUID field to various models
* Introduced GUID field to Database, Domain, DomainTable, Schema, Table, View, Sequence, Column, Index, Relationship, Constraint, Enum, and Script models.
* Updated initialization functions to assign new GUIDs using uuid package.
* Enhanced DCTX reader and writer to utilize GUIDs from models where available.
2026-01-04 19:53:17 +02:00

618 lines
16 KiB
Go

package drizzle
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// Reader parses Drizzle ORM schema source (TypeScript) into relspec models.
// It is meant to satisfy the readers.Reader interface through its
// ReadDatabase, ReadSchema and ReadTable methods.
type Reader struct {
// options carries the caller-supplied configuration. Within this file only
// FilePath (the input file or directory) and Metadata["name"] (an optional
// database name override, used when it is a string) are consulted.
options *readers.ReaderOptions
}
// NewReader builds a Drizzle Reader around the supplied options.
// The options pointer is stored as given; it is not copied.
func NewReader(options *readers.ReaderOptions) *Reader {
	reader := new(Reader)
	reader.options = options
	return reader
}
// ReadDatabase loads the Drizzle schema located at options.FilePath and
// converts it into a Database model.
//
// A directory path is delegated to readDirectory; any other path is read as a
// single file and handed to parseDrizzle. An error is returned when the path
// is empty, cannot be stat'ed, or cannot be read.
func (r *Reader) ReadDatabase() (*models.Database, error) {
	path := r.options.FilePath
	if path == "" {
		return nil, fmt.Errorf("file path is required for Drizzle reader")
	}

	stat, statErr := os.Stat(path)
	switch {
	case statErr != nil:
		return nil, fmt.Errorf("failed to stat path: %w", statErr)
	case stat.IsDir():
		// Directory input: every *.ts file inside is parsed and merged.
		return r.readDirectory(path)
	}

	// Regular file input.
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read file: %w", readErr)
	}
	return r.parseDrizzle(string(raw))
}
// ReadSchema runs ReadDatabase and hands back the first schema of the result.
// It fails when ReadDatabase fails or when the database holds no schemas.
func (r *Reader) ReadSchema() (*models.Schema, error) {
	database, err := r.ReadDatabase()
	if err != nil {
		return nil, err
	}

	if schemas := database.Schemas; len(schemas) > 0 {
		// Only the leading schema is exposed through this entry point.
		return schemas[0], nil
	}
	return nil, fmt.Errorf("no schemas found in Drizzle schema")
}
// ReadTable runs ReadSchema and hands back the first table of that schema.
// It fails when ReadSchema fails or when the schema holds no tables.
func (r *Reader) ReadTable() (*models.Table, error) {
	firstSchema, err := r.ReadSchema()
	if err != nil {
		return nil, err
	}

	if tables := firstSchema.Tables; len(tables) > 0 {
		// Only the leading table is exposed through this entry point.
		return tables[0], nil
	}
	return nil, fmt.Errorf("no tables found in Drizzle schema")
}
// readDirectory parses every *.ts file found directly inside dirPath (no
// recursion) and merges the tables and enums of each file into one "public"
// schema of a single Database.
//
// The database name defaults to "database" and is replaced by
// options.Metadata["name"] when that value is a string. The first file that
// cannot be read or parsed aborts the whole operation with an error naming
// that file.
func (r *Reader) readDirectory(dirPath string) (*models.Database, error) {
	db := models.InitDatabase("database")
	// parseDrizzle tags its result as PostgreSQL; the merged database built
	// here must carry the same type so that directory input and single-file
	// input produce consistent models.
	db.DatabaseType = models.PostgresqlDatabaseType
	if r.options.Metadata != nil {
		if name, ok := r.options.Metadata["name"].(string); ok {
			db.Name = name
		}
	}

	// Default schema for Drizzle
	schema := models.InitSchema("public")
	schema.Enums = make([]*models.Enum, 0)

	// Collect all .ts files in the directory.
	files, err := filepath.Glob(filepath.Join(dirPath, "*.ts"))
	if err != nil {
		return nil, fmt.Errorf("failed to glob directory: %w", err)
	}

	for _, file := range files {
		content, err := os.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("failed to read file %s: %w", file, err)
		}

		fileDB, err := r.parseDrizzle(string(content))
		if err != nil {
			return nil, fmt.Errorf("failed to parse file %s: %w", file, err)
		}

		// Each parsed file yields at most one schema; fold it into ours.
		if len(fileDB.Schemas) > 0 {
			fileSchema := fileDB.Schemas[0]
			schema.Tables = append(schema.Tables, fileSchema.Tables...)
			schema.Enums = append(schema.Enums, fileSchema.Enums...)
		}
	}

	db.Schemas = append(db.Schemas, schema)
	return db, nil
}
// parseDrizzle converts the text of one Drizzle schema file into a Database
// model that contains a single "public" schema and is tagged as PostgreSQL.
//
// The content is processed line by line:
//   - blank lines and lines starting with "//" are skipped;
//   - `export const x = pgEnum('Name', [...])` lines are parsed by parsePgEnum;
//   - `export const x = pgTable('name', ...` opens a table block whose lines
//     are accumulated until the braces balance on a line containing ");" (or
//     the depth goes negative), then handed to parseTableBlock.
//
// A table block that is still open when another pgTable/pgEnum export starts,
// or when the input ends, is finalised with the lines collected so far instead
// of being discarded, so definitions written without a trailing semicolon are
// kept. A scanner failure is returned as an error.
func (r *Reader) parseDrizzle(content string) (*models.Database, error) {
	db := models.InitDatabase("database")
	if r.options.Metadata != nil {
		if name, ok := r.options.Metadata["name"].(string); ok {
			db.Name = name
		}
	}

	// Default schema for Drizzle (PostgreSQL)
	schema := models.InitSchema("public")
	schema.Enums = make([]*models.Enum, 0)
	db.DatabaseType = models.PostgresqlDatabaseType

	// Match: export const users = pgTable('users', {
	pgTableRegex := regexp.MustCompile(`export\s+const\s+(\w+)\s*=\s*pgTable\s*\(\s*['"](\w+)['"]`)
	// Match: export const userRole = pgEnum('UserRole', ['admin', 'user']);
	pgEnumRegex := regexp.MustCompile(`export\s+const\s+(\w+)\s*=\s*pgEnum\s*\(\s*['"](\w+)['"]`)

	// State of the table block currently being collected (nil when none).
	var (
		currentTable        *models.Table
		currentTableVarName string
		blockDepth          int
		tableLines          []string
	)

	// flushTable parses and stores the pending table, if any, and resets the
	// block state.
	flushTable := func() {
		if currentTable == nil {
			return
		}
		r.parseTableBlock(tableLines, currentTable, currentTableVarName)
		schema.Tables = append(schema.Tables, currentTable)
		currentTable = nil
		tableLines = nil
		blockDepth = 0
	}

	scanner := bufio.NewScanner(strings.NewReader(content))
	// Allow a line to be as long as the whole input; the default limit of
	// bufio.MaxScanTokenSize would otherwise stop the scan on a long line.
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), len(content)+1)

	for scanner.Scan() {
		line := scanner.Text()
		trimmed := strings.TrimSpace(line)

		// Skip empty lines and comments
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}

		// A new enum export ends any table block that was left open.
		if matches := pgEnumRegex.FindStringSubmatch(trimmed); matches != nil {
			flushTable()
			if enum := r.parsePgEnum(trimmed, matches); enum != nil {
				schema.Enums = append(schema.Enums, enum)
			}
			continue
		}

		if matches := pgTableRegex.FindStringSubmatch(trimmed); matches != nil {
			// A new table export ends any table block that was left open.
			flushTable()
			currentTableVarName = matches[1]
			currentTable = models.InitTable(matches[2], "public")
			blockDepth = strings.Count(line, "{") - strings.Count(line, "}")
			tableLines = []string{line}
		} else if currentTable != nil {
			tableLines = append(tableLines, line)
			blockDepth += strings.Count(line, "{") - strings.Count(line, "}")
		} else {
			// Not inside a table and not a recognised export: ignore.
			continue
		}

		// The same closing test applies to the opening line, so a table
		// written entirely on one line is closed immediately.
		if blockDepth < 0 || (blockDepth == 0 && strings.Contains(line, ");")) {
			flushTable()
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to scan Drizzle schema: %w", err)
	}

	// Keep a table whose block was never explicitly closed.
	flushTable()

	db.Schemas = append(db.Schemas, schema)
	return db, nil
}
// parsePgEnum turns a single-line pgEnum declaration into an Enum model in the
// "public" schema.
//
// matches is the submatch slice of the pgEnum pattern: matches[1] is the
// TypeScript variable name (unused here) and matches[2] the enum name. The
// values are taken from the first [...] group on the line, split on commas,
// and stripped of surrounding whitespace and quote characters; empty entries
// are dropped. nil is returned when the line holds no [...] group.
func (r *Reader) parsePgEnum(line string, matches []string) *models.Enum {
	// Example: pgEnum('UserRole', ['admin', 'user', 'guest'])
	bracketed := regexp.MustCompile(`\[(.*?)\]`).FindStringSubmatch(line)
	if bracketed == nil {
		return nil
	}

	labels := []string{}
	for _, raw := range strings.Split(bracketed[1], ",") {
		if label := strings.Trim(strings.TrimSpace(raw), "'\""); label != "" {
			labels = append(labels, label)
		}
	}

	result := models.InitEnum(matches[2], "public")
	result.Values = labels
	return result
}
// parseTableBlock interprets the collected source lines of one pgTable call,
// whose expected shape is:
//
//	pgTable('name', { columns }, (table) => [indexes])
//
// The text between the first "{" and its balancing "}" is passed to
// parseColumnsBlock. When that closing brace is followed by an arrow-function
// callback opening an array, the text between the first "[" after the columns
// and its balancing "]" is passed to parseIndexBlock. The function returns
// silently when a required delimiter is missing or unbalanced.
func (r *Reader) parseTableBlock(lines []string, table *models.Table, tableVarName string) {
	source := strings.Join(lines, "\n")

	// balancedEnd scans forward from `from` and reports the position of the
	// `closer` that brings the nesting of opener/closer back to zero, or -1.
	// Delimiters inside string literals are counted like any other.
	balancedEnd := func(from int, opener, closer byte) int {
		nesting := 0
		for pos := from; pos < len(source); pos++ {
			switch source[pos] {
			case opener:
				nesting++
			case closer:
				nesting--
				if nesting == 0 {
					return pos
				}
			}
		}
		return -1
	}

	// Columns object literal.
	open := strings.Index(source, "{")
	if open == -1 {
		return
	}
	closeAt := balancedEnd(open, '{', '}')
	if closeAt == -1 {
		return
	}
	r.parseColumnsBlock(source[open+1:closeAt], table, tableVarName)

	// Optional index callback: `, (table) => [` or `, ({ col1, col2 }) => [`.
	// (?s) lets "." span newlines inside the parameter list.
	rest := source[closeAt:]
	if !regexp.MustCompile(`(?s)}\s*,\s*\(.*?\)\s*=>\s*\[`).MatchString(rest) {
		return
	}
	offset := strings.Index(rest, "[")
	if offset == -1 {
		return
	}
	arrayOpen := closeAt + offset
	arrayClose := balancedEnd(arrayOpen, '[', ']')
	if arrayClose == -1 {
		return
	}
	r.parseIndexBlock(source[arrayOpen+1:arrayClose], table, tableVarName)
}
// parseColumnsBlock parses the body of a table's columns object literal and
// stores every recognised column in table.Columns, keyed by column name.
//
// The block is read line by line; blank lines and lines starting with "//" are
// ignored. A line that starts with "." is treated as the continuation of a
// wrapped modifier chain and is appended to the preceding statement, so
//
//	id: integer('id')
//	  .primaryKey(),
//
// is handled like its single-line form. Statements of the shape
// `fieldName: columnType(` are passed to parseColumnDefinition; everything else
// is skipped. tableVarName is accepted for symmetry with the other block
// parsers and is not used here.
func (r *Reader) parseColumnsBlock(block string, table *models.Table, tableVarName string) {
	// Match: fieldName: columnType('columnName').modifier().modifier(),
	// Example: id: integer('id').primaryKey(),
	// Compiled once per call instead of once per line.
	columnRegex := regexp.MustCompile(`(\w+):\s*(\w+)\s*\(`)

	// Fold wrapped modifier chains back into one statement per column.
	var statements []string
	for _, line := range strings.Split(block, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}
		if strings.HasPrefix(trimmed, ".") && len(statements) > 0 {
			statements[len(statements)-1] += trimmed
			continue
		}
		statements = append(statements, trimmed)
	}

	for _, stmt := range statements {
		matches := columnRegex.FindStringSubmatch(stmt)
		if matches == nil {
			continue
		}
		// matches[1] is the object key, matches[2] the column builder name.
		if col := r.parseColumnDefinition(stmt, matches[1], matches[2], table); col != nil {
			table.Columns[col.Name] = col
		}
	}
}
// drizzleInlineEnumColumnRegex matches the inline enum column form
// pgEnum('EnumName')('column_name'); group 1 is the enum, group 2 the column.
var drizzleInlineEnumColumnRegex = regexp.MustCompile(`pgEnum\s*\(['"](\w+)['"]\)\s*\(['"](\w+)['"]\)`)

// drizzleColumnNameArgRegex matches `field: builder(` and, when the first
// argument is a quoted string, captures it as group 1. Anchoring on the
// builder call keeps string arguments of later modifiers (for example
// .default('x')) from being mistaken for the column name.
var drizzleColumnNameArgRegex = regexp.MustCompile(`\w+:\s*\w+\s*\(\s*(?:['"]([^'"]+)['"])?`)

// parseColumnDefinition builds a Column from one column statement.
//
// line is the full statement, fieldName the object key and drizzleType the
// name of the column builder that follows it. Two forms are recognised:
//
//   - pgEnum('EnumName')('column_name'): the column is named after the second
//     string and typed with the enum name;
//   - builder('column_name', ...) or builder(...): the column is named after
//     the builder's quoted first argument when there is one, otherwise after
//     fieldName; its type is drizzleTypeToSQL(drizzleType).
//
// Columns start out nullable and are then refined by parseColumnModifiers,
// which may also register constraints on table. nil is returned only when no
// column name can be determined.
func (r *Reader) parseColumnDefinition(line, fieldName, drizzleType string, table *models.Table) *models.Column {
	if enumMatch := drizzleInlineEnumColumnRegex.FindStringSubmatch(line); enumMatch != nil {
		column := models.InitColumn(enumMatch[2], table.Name, table.Schema)
		column.Type = enumMatch[1]
		column.NotNull = false
		r.parseColumnModifiers(line, column, table)
		return column
	}

	// Prefer the explicit name argument; fall back to the object key when the
	// builder is called without a leading string, e.g. integer() or
	// bigint({ mode: 'number' }).
	columnName := fieldName
	if m := drizzleColumnNameArgRegex.FindStringSubmatch(line); m != nil && m[1] != "" {
		columnName = m[1]
	}
	if columnName == "" {
		return nil
	}

	column := models.InitColumn(columnName, table.Name, table.Schema)
	column.Type = r.drizzleTypeToSQL(drizzleType)
	// Default: columns are nullable unless a modifier says otherwise.
	column.NotNull = false
	r.parseColumnModifiers(line, column, table)
	return column
}
// drizzleTypeToSQL translates the name of a Drizzle column builder into the
// SQL type name stored on the column.
//
// Of the known builders only "doublePrecision" is spelled differently in SQL
// ("double precision"); the others map to themselves. Names that are not
// recognised (they might be enum helpers) are returned unchanged.
func (r *Reader) drizzleTypeToSQL(drizzleType string) string {
	switch drizzleType {
	case "doublePrecision":
		return "double precision"
	case
		// Integer types
		"integer", "bigint", "smallint",
		// Serial types
		"serial", "bigserial", "smallserial",
		// Numeric types
		"numeric", "real",
		// Character types
		"text", "varchar", "char",
		// Boolean
		"boolean",
		// Binary
		"bytea",
		// JSON
		"json", "jsonb",
		// Date/Time
		"time", "timestamp", "date", "interval",
		// UUID
		"uuid",
		// Geometric
		"point", "line":
		// Builder name and SQL type name coincide.
		return drizzleType
	}
	// If not found, might be an enum - return as-is
	return drizzleType
}
// drizzleReferencesRegex matches .references(() => otherTable.column ...),
// optionally with a return type annotation on the arrow function (as used for
// self references) and with further arguments after the column, e.g.
// .references(() => users.id, { onDelete: 'cascade' }).
// Group 1 is the referenced table variable, group 2 the referenced column.
var drizzleReferencesRegex = regexp.MustCompile(`\.references\(\s*\(\)\s*(?::\s*\w+\s*)?=>\s*(\w+)\.(\w+)`)

// parseColumnModifiers inspects the chained modifiers on a column statement
// and applies them to column, registering derived constraints on table:
//
//   - .primaryKey()                 sets IsPrimaryKey and NotNull
//   - .notNull()                    sets NotNull
//   - .unique()                     adds a unique constraint named uq_<column>
//   - .default(expr)                sets Default through parseDefaultValue
//   - .defaultNow()                 sets Default to "now()", the same value
//     produced for .default(sql`now()`)
//   - .generatedAlwaysAsIdentity()  sets AutoIncrement
//   - .references(() => t.col ...)  adds a foreign key named fk_<table>_<column>
//     that targets the same schema
//
// Modifiers are detected textually anywhere in line.
func (r *Reader) parseColumnModifiers(line string, column *models.Column, table *models.Table) {
	if strings.Contains(line, ".primaryKey()") {
		column.IsPrimaryKey = true
		column.NotNull = true
	}

	if strings.Contains(line, ".notNull()") {
		column.NotNull = true
	}

	if strings.Contains(line, ".unique()") {
		uniqueConstraint := models.InitConstraint(
			fmt.Sprintf("uq_%s", column.Name),
			models.UniqueConstraint,
		)
		uniqueConstraint.Schema = table.Schema
		uniqueConstraint.Table = table.Name
		uniqueConstraint.Columns = []string{column.Name}
		table.Constraints[uniqueConstraint.Name] = uniqueConstraint
	}

	// .default(...): find the parenthesis that closes the call. Parentheses
	// inside a backtick template (sql`...`) are not counted, so SQL such as
	// sql`now()` does not end the argument early.
	if defaultIdx := strings.Index(line, ".default("); defaultIdx != -1 {
		start := defaultIdx + len(".default(")
		depth := 1
		inBacktick := false
		i := start
		for i < len(line) && depth > 0 {
			ch := line[i]
			if ch == '`' {
				inBacktick = !inBacktick
			} else if !inBacktick {
				switch ch {
				case '(':
					depth++
				case ')':
					depth--
				}
			}
			i++
		}
		// depth stays above zero when the call is never closed on this line;
		// in that case the default is left unset.
		if depth == 0 {
			defaultValue := strings.TrimSpace(line[start : i-1])
			r.parseDefaultValue(defaultValue, column)
		}
	}

	// .defaultNow() is Drizzle's shorthand for a now() default.
	if strings.Contains(line, ".defaultNow()") {
		column.Default = "now()"
	}

	if strings.Contains(line, ".generatedAlwaysAsIdentity()") {
		column.AutoIncrement = true
	}

	if matches := drizzleReferencesRegex.FindStringSubmatch(line); matches != nil {
		refTableVar := matches[1]
		refColumn := matches[2]

		constraintName := fmt.Sprintf("fk_%s_%s", table.Name, column.Name)
		constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
		constraint.Schema = table.Schema
		constraint.Table = table.Name
		constraint.Columns = []string{column.Name}
		constraint.ReferencedSchema = table.Schema // Assume same schema
		constraint.ReferencedTable = r.varNameToTableName(refTableVar)
		constraint.ReferencedColumns = []string{refColumn}
		table.Constraints[constraint.Name] = constraint
	}
}
// parseDefaultValue interprets the argument text of a .default(...) call and
// stores the result in column.Default:
//
//   - sql`expr`             -> the text between the backticks, as a string
//   - true / false          -> the corresponding bool
//   - 'text' or "text"      -> the text without the surrounding quotes
//   - anything else         -> the expression text unchanged (numbers
//     therefore stay strings)
//
// Leading and trailing whitespace is ignored.
func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) {
	defaultExpr = strings.TrimSpace(defaultExpr)

	// Handle SQL expressions like sql`now()`
	sqlRegex := regexp.MustCompile("sql`([^`]+)`")
	if match := sqlRegex.FindStringSubmatch(defaultExpr); match != nil {
		column.Default = match[1]
		return
	}

	switch defaultExpr {
	case "true":
		column.Default = true
		return
	case "false":
		column.Default = false
		return
	}

	// String literals. The length check matters: a lone quote character is
	// both prefix and suffix of itself, and slicing [1:0] would panic.
	if len(defaultExpr) >= 2 {
		for _, quote := range []string{"'", "\""} {
			if strings.HasPrefix(defaultExpr, quote) && strings.HasSuffix(defaultExpr, quote) {
				column.Default = defaultExpr[1 : len(defaultExpr)-1]
				return
			}
		}
	}

	// Everything else (numbers, identifiers, ...) is kept as written.
	column.Default = defaultExpr
}
// parseIndexBlock parses the array body returned by a table's index callback.
//
// The block is read line by line; blank lines and lines starting with "//" are
// ignored. Lines of the form
//
//	index('index_name').on(table.col1, table.col2)
//	uniqueIndex('index_name').on(table.col1, table.col2)
//
// are recognised. Each listed column is reduced to the identifier after its
// last "." and empty entries (for example after a trailing comma) are dropped.
// index(...) is stored as a non-unique entry in table.Indexes, uniqueIndex(...)
// as a unique constraint in table.Constraints. tableVarName is not used.
func (r *Reader) parseIndexBlock(block string, table *models.Table, tableVarName string) {
	// Compiled once per call instead of once per line.
	indexRegex := regexp.MustCompile(`(uniqueIndex|index)\s*\(['"](\w+)['"]\)\s*\.on\s*\((.*?)\)`)

	for _, line := range strings.Split(block, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}

		matches := indexRegex.FindStringSubmatch(trimmed)
		if matches == nil {
			continue
		}
		indexType, indexName := matches[1], matches[2]

		// Parse column list
		columnParts := strings.Split(matches[3], ",")
		columns := make([]string, 0, len(columnParts))
		for _, part := range columnParts {
			// Remove table prefix: table.column -> column
			cleaned := strings.TrimSpace(part)
			if dot := strings.LastIndex(cleaned, "."); dot != -1 {
				cleaned = cleaned[dot+1:]
			}
			if cleaned == "" {
				continue
			}
			columns = append(columns, cleaned)
		}

		if indexType == "uniqueIndex" {
			// Unique indexes are modelled as unique constraints.
			constraint := models.InitConstraint(indexName, models.UniqueConstraint)
			constraint.Schema = table.Schema
			constraint.Table = table.Name
			constraint.Columns = columns
			table.Constraints[constraint.Name] = constraint
			continue
		}

		index := models.InitIndex(indexName, table.Name, table.Schema)
		index.Columns = columns
		index.Unique = false
		table.Indexes[index.Name] = index
	}
}
// varNameToTableName resolves the TypeScript variable that holds a table
// definition to the name of that table.
//
// The mapping is currently the identity: the variable name is assumed to be
// the table name. A table exported under a different variable name is
// therefore not resolved to its declared name.
func (r *Reader) varNameToTableName(varName string) string {
	// TODO: add real conversion logic (for example inflection) if needed.
	tableName := varName
	return tableName
}