// Package drizzle reads Drizzle ORM TypeScript schema sources (pgTable and
// pgEnum declarations) into relspec models.
package drizzle

import (
	"bufio"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/readers"
)

// Patterns used while scanning Drizzle sources. They are compiled once here
// instead of on every call (several used to be recompiled for every line).
var (
	// export const users = pgTable('users', {
	pgTableRe = regexp.MustCompile(`export\s+const\s+(\w+)\s*=\s*pgTable\s*\(\s*['"](\w+)['"]`)

	// export const userRole = pgEnum('UserRole', ['admin', 'user']);
	pgEnumRe = regexp.MustCompile(`export\s+const\s+(\w+)\s*=\s*pgEnum\s*\(\s*['"](\w+)['"]`)

	// The value array of a pgEnum call: ['admin', 'user']
	enumValuesRe = regexp.MustCompile(`\[(.*?)\]`)

	// Start of the extra-config callback that follows the columns object:
	//   }, (table) => [   or   }, ({ col1, col2 }) => [
	// (?s) lets the parameter list span newlines.
	indexCallbackRe = regexp.MustCompile(`(?s)}\s*,\s*\(.*?\)\s*=>\s*\[`)

	// fieldName: columnType('column_name' ...
	// Group 3 (the SQL column name) is optional because Drizzle allows the name
	// to be omitted, and the name may be followed by an options object such as
	// varchar('email', { length: 256 }).
	columnDefRe = regexp.MustCompile(`(\w+):\s*(\w+)\s*\(\s*(?:['"]([^'"]+)['"])?`)

	// pgEnum('EnumName')('column_name')
	enumColumnRe = regexp.MustCompile(`pgEnum\s*\(['"](\w+)['"]\)\s*\(['"](\w+)['"]\)`)
)

// Reader implements the readers.Reader interface for Drizzle schema format
type Reader struct {
	options *readers.ReaderOptions
}

// NewReader creates a new Drizzle reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
	return &Reader{
		options: options,
	}
}

// ReadDatabase reads and parses Drizzle schema input, returning a Database model.
//
// options.FilePath may name a single schema file or a directory. In the
// directory case every *.ts file directly inside it is parsed and merged into
// one schema.
func (r *Reader) ReadDatabase() (*models.Database, error) {
	// Guard against a nil options pointer as well as an empty path.
	if r.options == nil || r.options.FilePath == "" {
		return nil, fmt.Errorf("file path is required for Drizzle reader")
	}

	info, err := os.Stat(r.options.FilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to stat path: %w", err)
	}
	if info.IsDir() {
		return r.readDirectory(r.options.FilePath)
	}

	content, err := os.ReadFile(r.options.FilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}
	return r.parseDrizzle(string(content))
}

// ReadSchema reads and parses Drizzle schema input, returning the first Schema
// of the parsed database.
func (r *Reader) ReadSchema() (*models.Schema, error) {
	db, err := r.ReadDatabase()
	if err != nil {
		return nil, err
	}
	if len(db.Schemas) == 0 {
		return nil, fmt.Errorf("no schemas found in Drizzle schema")
	}
	return db.Schemas[0], nil
}

// ReadTable reads and parses Drizzle schema input, returning the first Table
// of the first schema.
func (r *Reader) ReadTable() (*models.Table, error) {
	schema, err := r.ReadSchema()
	if err != nil {
		return nil, err
	}
	if len(schema.Tables) == 0 {
		return nil, fmt.Errorf("no tables found in Drizzle schema")
	}
	return schema.Tables[0], nil
}

// newDrizzleDatabase creates the Database shell shared by single-file and
// directory reads. It is named from options.Metadata["name"] when that is a
// string, and typed as PostgreSQL because this reader only understands the
// pgTable/pgEnum builders.
func (r *Reader) newDrizzleDatabase() *models.Database {
	db := models.InitDatabase("database")
	// Reading a missing key (or from a nil map) yields nil, so no nil check is needed.
	if name, ok := r.options.Metadata["name"].(string); ok {
		db.Name = name
	}
	db.DatabaseType = models.PostgresqlDatabaseType
	return db
}

// readDirectory parses every regular *.ts file directly inside dirPath. It
// merges their tables and enums into a single "public" schema.
// Cross-file references (foreign keys, enum column types) are resolved after
// all files have been read.
func (r *Reader) readDirectory(dirPath string) (*models.Database, error) {
	db := r.newDrizzleDatabase()
	schema := models.InitSchema("public")
	schema.Enums = make([]*models.Enum, 0)

	// os.ReadDir (sorted by name, like Glob) reports I/O errors that Glob
	// would swallow, and is unaffected by glob metacharacters in dirPath.
	entries, err := os.ReadDir(dirPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read directory: %w", err)
	}

	syms := drizzleSymbols{tables: make(map[string]string), enums: make(map[string]string)}
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".ts") {
			continue
		}
		file := filepath.Join(dirPath, entry.Name())

		content, err := os.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("failed to read file %s: %w", file, err)
		}

		fileSchema, fileSyms, err := r.parseDrizzleSchema(string(content))
		if err != nil {
			return nil, fmt.Errorf("failed to parse file %s: %w", file, err)
		}

		schema.Tables = append(schema.Tables, fileSchema.Tables...)
		schema.Enums = append(schema.Enums, fileSchema.Enums...)
		for v, name := range fileSyms.tables {
			syms.tables[v] = name
		}
		for v, name := range fileSyms.enums {
			syms.enums[v] = name
		}
	}

	r.resolveSymbols(schema.Tables, syms)
	db.Schemas = append(db.Schemas, schema)
	return db, nil
}

// parseDrizzle parses one Drizzle schema source and returns a Database model
// containing a single "public" schema.
func (r *Reader) parseDrizzle(content string) (*models.Database, error) {
	db := r.newDrizzleDatabase()

	schema, syms, err := r.parseDrizzleSchema(content)
	if err != nil {
		return nil, err
	}
	r.resolveSymbols(schema.Tables, syms)

	db.Schemas = append(db.Schemas, schema)
	return db, nil
}

// drizzleSymbols maps TypeScript variable names declared in schema sources to
// the SQL names they stand for.
type drizzleSymbols struct {
	tables map[string]string // pgTable variable name -> SQL table name
	enums  map[string]string // pgEnum variable name -> SQL enum name
}

// parseDrizzleSchema scans one Drizzle source line by line. It returns the
// "public" schema it declares, together with the variable-name symbols needed
// to resolve references. It returns an error only if scanning the content fails.
func (r *Reader) parseDrizzleSchema(content string) (*models.Schema, drizzleSymbols, error) {
	schema := models.InitSchema("public")
	schema.Enums = make([]*models.Enum, 0)
	syms := drizzleSymbols{tables: make(map[string]string), enums: make(map[string]string)}

	scanner := bufio.NewScanner(strings.NewReader(content))
	// The content is already in memory, so let a single line be as long as the
	// whole input. Otherwise bufio's default 64 KiB token limit would abort the
	// scan (e.g. on minified or generated schema files).
	maxLine := len(content) + 1
	if maxLine < bufio.MaxScanTokenSize {
		maxLine = bufio.MaxScanTokenSize
	}
	scanner.Buffer(make([]byte, 0, 4096), maxLine)

	var (
		currentTable        *models.Table
		currentTableVarName string
		blockDepth          int
		tableLines          []string
	)

	// finishTable parses the accumulated pgTable block (if any), appends the
	// table to the schema and resets the block state.
	finishTable := func() {
		if currentTable != nil {
			r.parseTableBlock(tableLines, currentTable, currentTableVarName)
			schema.Tables = append(schema.Tables, currentTable)
		}
		currentTable = nil
		currentTableVarName = ""
		blockDepth = 0
		tableLines = nil
	}

	// blockClosed reports whether the pgTable(...) call ends on line, given
	// the brace depth after counting that line.
	blockClosed := func(line string) bool {
		return blockDepth < 0 || (blockDepth == 0 && strings.Contains(line, ");"))
	}

	for scanner.Scan() {
		line := scanner.Text()
		trimmed := strings.TrimSpace(line)

		// Skip empty lines and line comments.
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}

		if m := pgEnumRe.FindStringSubmatch(trimmed); m != nil {
			syms.enums[m[1]] = m[2]
			if enum := r.parsePgEnum(trimmed, m); enum != nil {
				schema.Enums = append(schema.Enums, enum)
			}
			continue
		}

		if m := pgTableRe.FindStringSubmatch(trimmed); m != nil {
			// A previous block that never reached ");" (e.g. a file written
			// without semicolons) is parsed now instead of being discarded.
			finishTable()

			currentTableVarName = m[1]
			currentTable = models.InitTable(m[2], "public")
			syms.tables[m[1]] = m[2]
			blockDepth = strings.Count(line, "{") - strings.Count(line, "}")
			tableLines = []string{line}

			// A definition written on one line also closes on that line.
			if blockClosed(line) {
				finishTable()
			}
			continue
		}

		if currentTable == nil {
			continue
		}

		tableLines = append(tableLines, line)
		blockDepth += strings.Count(line, "{") - strings.Count(line, "}")
		if blockClosed(line) {
			finishTable()
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, drizzleSymbols{}, fmt.Errorf("failed to scan Drizzle schema: %w", err)
	}

	// Keep a table that is still open at end of input (e.g. the last
	// statement has no trailing ";").
	finishTable()

	return schema, syms, nil
}

// resolveSymbols rewrites names that were recorded as TypeScript variable
// names into the SQL names declared by pgTable/pgEnum. It affects enum
// column types (role: userRole('role')) and foreign-key targets
// (.references(() => users.id)). Names not found in syms are left untouched.
func (r *Reader) resolveSymbols(tables []*models.Table, syms drizzleSymbols) {
	for _, table := range tables {
		for _, col := range table.Columns {
			if name, ok := syms.enums[col.Type]; ok {
				col.Type = name
			}
		}
		for _, c := range table.Constraints {
			if name, ok := syms.tables[c.ReferencedTable]; ok {
				c.ReferencedTable = name
			}
		}
	}
}

// parsePgEnum builds an Enum from a single-line pgEnum definition such as
// pgEnum('UserRole', ['admin', 'user', 'guest']). matches is the pgEnumRe
// submatch (matches[2] is the enum name). It returns nil when the value array
// is not on this line.
func (r *Reader) parsePgEnum(line string, matches []string) *models.Enum {
	enumName := matches[2]

	valuesMatch := enumValuesRe.FindStringSubmatch(line)
	if valuesMatch == nil {
		return nil
	}

	values := make([]string, 0)
	for _, part := range strings.Split(valuesMatch[1], ",") {
		cleaned := strings.Trim(strings.TrimSpace(part), "'\"")
		if cleaned != "" {
			values = append(values, cleaned)
		}
	}

	enum := models.InitEnum(enumName, "public")
	enum.Values = values
	return enum
}

// parseTableBlock parses a complete pgTable definition block of the form
// pgTable('name', { columns }, (table) => [indexes]). It fills in table's
// columns, and its indexes when the callback is present.
func (r *Reader) parseTableBlock(lines []string, table *models.Table, tableVarName string) {
	fullText := strings.Join(lines, "\n")

	// The columns object starts at the first '{' of the block.
	columnsStart := strings.Index(fullText, "{")
	if columnsStart == -1 {
		return
	}

	// Find the brace that closes the columns object.
	depth := 0
	columnsEnd := -1
	for i := columnsStart; i < len(fullText) && columnsEnd == -1; i++ {
		switch fullText[i] {
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				columnsEnd = i
			}
		}
	}
	if columnsEnd == -1 {
		return
	}

	r.parseColumnsBlock(fullText[columnsStart+1:columnsEnd], table, tableVarName)

	// Optional extra-config callback returning an array of indexes.
	rest := fullText[columnsEnd:]
	if !indexCallbackRe.MatchString(rest) {
		return
	}
	indexStart := strings.Index(rest, "[")
	if indexStart == -1 {
		return
	}
	indexStart += columnsEnd

	indexDepth := 0
	indexEnd := -1
	for i := indexStart; i < len(fullText) && indexEnd == -1; i++ {
		switch fullText[i] {
		case '[':
			indexDepth++
		case ']':
			indexDepth--
			if indexDepth == 0 {
				indexEnd = i
			}
		}
	}
	if indexEnd != -1 {
		r.parseIndexBlock(fullText[indexStart+1:indexEnd], table, tableVarName)
	}
}

// parseColumnsBlock parses the body of a pgTable columns object line by line.
// It adds every recognised "field: builder(...)" definition to table.Columns.
func (r *Reader) parseColumnsBlock(block string, table *models.Table, tableVarName string) {
	for _, line := range strings.Split(block, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}

		// Example: id: integer('id').primaryKey(),
		m := columnDefRe.FindStringSubmatch(trimmed)
		if m == nil {
			continue
		}

		if col := r.parseColumnDefinition(trimmed, m[1], m[2], table); col != nil {
			table.Columns[col.Name] = col
		}
	}
}

// parseColumnDefinition parses a single column definition line into a Column.
// Columns are nullable unless a modifier says otherwise. The SQL column name
// is the string literal passed to the builder, or the field key when the
// literal is omitted, e.g. id: serial().
func (r *Reader) parseColumnDefinition(line, fieldName, drizzleType string, table *models.Table) *models.Column {
	// Enum column syntax: pgEnum('EnumName')('column_name')
	if m := enumColumnRe.FindStringSubmatch(line); m != nil {
		column := models.InitColumn(m[2], table.Name, table.Schema)
		column.Type = m[1]
		column.NotNull = false
		r.parseColumnModifiers(line, column, table)
		return column
	}

	// Take the name from the builder call itself, not from any later call on
	// the line (e.g. .default('x')); an options object may follow the name.
	m := columnDefRe.FindStringSubmatch(line)
	if m == nil {
		return nil
	}
	columnName := m[3]
	if columnName == "" {
		columnName = fieldName
	}

	column := models.InitColumn(columnName, table.Name, table.Schema)
	column.Type = r.drizzleTypeToSQL(drizzleType)
	column.NotNull = false
	r.parseColumnModifiers(line, column, table)
	return column
}

// drizzleSQLTypes maps Drizzle pg-core column builder names to SQL types.
var drizzleSQLTypes = map[string]string{
	// Integer types
	"integer":  "integer",
	"bigint":   "bigint",
	"smallint": "smallint",

	// Serial types
	"serial":      "serial",
	"bigserial":   "bigserial",
	"smallserial": "smallserial",

	// Numeric types
	"numeric":         "numeric",
	"real":            "real",
	"doublePrecision": "double precision",

	// Character types
	"text":    "text",
	"varchar": "varchar",
	"char":    "char",

	// Boolean
	"boolean": "boolean",

	// Binary
	"bytea": "bytea",

	// JSON
	"json":  "json",
	"jsonb": "jsonb",

	// Date/Time
	"time":      "time",
	"timestamp": "timestamp",
	"date":      "date",
	"interval":  "interval",

	// UUID
	"uuid": "uuid",

	// Geometric
	"point": "point",
	"line":  "line",
}

// drizzleTypeToSQL converts a Drizzle column builder name to its SQL type.
// Unknown names (for example enum builders) are returned unchanged.
func (r *Reader) drizzleTypeToSQL(drizzleType string) string {
	if sqlType, ok := drizzleSQLTypes[drizzleType]; ok {
		return sqlType
	}
	return drizzleType
}

// parseColumnModifiers parses column modifiers like .primaryKey(), .notNull(), etc.
func (r *Reader) parseColumnModifiers(line string, column *models.Column, table *models.Table) { // Check for .primaryKey() if strings.Contains(line, ".primaryKey()") { column.IsPrimaryKey = true column.NotNull = true } // Check for .notNull() if strings.Contains(line, ".notNull()") { column.NotNull = true } // Check for .unique() if strings.Contains(line, ".unique()") { uniqueConstraint := models.InitConstraint( fmt.Sprintf("uq_%s", column.Name), models.UniqueConstraint, ) uniqueConstraint.Schema = table.Schema uniqueConstraint.Table = table.Name uniqueConstraint.Columns = []string{column.Name} table.Constraints[uniqueConstraint.Name] = uniqueConstraint } // Check for .default(...) // Need to handle nested backticks and parentheses in SQL expressions defaultIdx := strings.Index(line, ".default(") if defaultIdx != -1 { start := defaultIdx + len(".default(") depth := 1 inBacktick := false i := start for i < len(line) && depth > 0 { ch := line[i] if ch == '`' { inBacktick = !inBacktick } else if !inBacktick { switch ch { case '(': depth++ case ')': depth-- } } i++ } if depth == 0 { defaultValue := strings.TrimSpace(line[start : i-1]) r.parseDefaultValue(defaultValue, column) } } // Check for .generatedAlwaysAsIdentity() if strings.Contains(line, ".generatedAlwaysAsIdentity()") { column.AutoIncrement = true } // Check for .references(() => otherTable.column) referencesRegex := regexp.MustCompile(`\.references\(\(\)\s*=>\s*(\w+)\.(\w+)\)`) if matches := referencesRegex.FindStringSubmatch(line); matches != nil { refTableVar := matches[1] refColumn := matches[2] // Create FK constraint constraintName := fmt.Sprintf("fk_%s_%s", table.Name, column.Name) constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint) constraint.Schema = table.Schema constraint.Table = table.Name constraint.Columns = []string{column.Name} constraint.ReferencedSchema = table.Schema // Assume same schema constraint.ReferencedTable = r.varNameToTableName(refTableVar) 
constraint.ReferencedColumns = []string{refColumn} table.Constraints[constraint.Name] = constraint } } // parseDefaultValue parses a default value expression func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) { defaultExpr = strings.TrimSpace(defaultExpr) // Handle SQL expressions like sql`now()` sqlRegex := regexp.MustCompile("sql`([^`]+)`") if match := sqlRegex.FindStringSubmatch(defaultExpr); match != nil { column.Default = match[1] return } // Handle boolean values if defaultExpr == "true" { column.Default = true return } if defaultExpr == "false" { column.Default = false return } // Handle string literals if strings.HasPrefix(defaultExpr, "'") && strings.HasSuffix(defaultExpr, "'") { column.Default = defaultExpr[1 : len(defaultExpr)-1] return } if strings.HasPrefix(defaultExpr, "\"") && strings.HasSuffix(defaultExpr, "\"") { column.Default = defaultExpr[1 : len(defaultExpr)-1] return } // Try to parse as number column.Default = defaultExpr } // parseIndexBlock parses the index callback block func (r *Reader) parseIndexBlock(block string, table *models.Table, tableVarName string) { // Split by lines lines := strings.Split(block, "\n") for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "" || strings.HasPrefix(trimmed, "//") { continue } // Match: index('index_name').on(table.col1, table.col2) // or: uniqueIndex('index_name').on(table.col1, table.col2) indexRegex := regexp.MustCompile(`(uniqueIndex|index)\s*\(['"](\w+)['"]\)\s*\.on\s*\((.*?)\)`) matches := indexRegex.FindStringSubmatch(trimmed) if matches == nil { continue } indexType := matches[1] indexName := matches[2] columnsStr := matches[3] // Parse column list columnParts := strings.Split(columnsStr, ",") columns := make([]string, 0) for _, part := range columnParts { // Remove table prefix: table.column -> column cleaned := strings.TrimSpace(part) if strings.Contains(cleaned, ".") { parts := strings.Split(cleaned, ".") cleaned = parts[len(parts)-1] } 
columns = append(columns, cleaned) } if indexType == "uniqueIndex" { // Create unique constraint constraint := models.InitConstraint(indexName, models.UniqueConstraint) constraint.Schema = table.Schema constraint.Table = table.Name constraint.Columns = columns table.Constraints[constraint.Name] = constraint } else { // Create index index := models.InitIndex(indexName, table.Name, table.Schema) index.Columns = columns index.Unique = false table.Indexes[index.Name] = index } } } // varNameToTableName converts a variable name to a table name // For now, just return as-is (could add inflection later) func (r *Reader) varNameToTableName(varName string) string { // TODO: Could add conversion logic here if needed // For now, assume variable name matches table name return varName }