diff --git a/AI_USE.md b/AI_USE.md new file mode 100644 index 0000000..397857c --- /dev/null +++ b/AI_USE.md @@ -0,0 +1,35 @@ +# AI Usage Declaration + +This Go project utilizes AI tools for the following purposes: + +- Generating and improving documentation +- Writing and enhancing tests +- Refactoring and optimizing existing code + +AI is **not** used for core design or architecture decisions. +All design decisions are deferred to human discussion. +AI is employed only for enhancements to human-written code. + +We are aware of significant AI hallucinations; all AI-generated content is to be reviewed and verified by humans. + + + .-""""""-. + .' '. + / O O \ + : ` : + | | + : .------. : + \ ' ' / + '. .' + '-......-' + MEGAMIND AI + [============] + + ___________ + /___________\ + /_____________\ + | ASSIMILATE | + | RESISTANCE | + | IS FUTILE | + \_____________/ + \___________/ diff --git a/README.md b/README.md index 2741403..9627a50 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,10 @@ RelSpec provides bidirectional conversion and comparison between various databas - **JSON** - Standard JSON schema output - **YAML** - Human-readable YAML format + +## Use of AI +[Rules and use of AI](./AI_USE.md) + ## Installation ```bash diff --git a/TODO.md b/TODO.md index 9eff2d8..7e35689 100644 --- a/TODO.md +++ b/TODO.md @@ -9,9 +9,10 @@ - [ ] MSSQL driver - [✔️] Foreign key detection - [✔️] Index extraction - - [ ] .sql file generation with sequence and priority + - [*] .sql file generation with sequence and priority - [✔️] .dbml: Database Markup Language (DBML) for textual schema representation. - [✔️] Prisma schema support (PSL format) .prisma +- [✔️] Drizzle ORM support .ts (TypeScript / JavaScript) (Mr. Edd wanted to move from Prisma to Drizzle. If you are bugs, you are welcome to do pull requests or issues) - [☠️] Entity Framework (.NET) model .edmx (Fuck no, EDMX files were bloated, verbose XML nightmares—hard to merge, error-prone, and a pain in teams. Microsoft wisely ditched them in EF Core for code-first. Classic overkill from old MS era.) - [✔️] TypeORM support - [] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) (💲 Someone can do this, not me) diff --git a/cmd/relspec/convert.go b/cmd/relspec/convert.go index a04c27d..7a6b54b 100644 --- a/cmd/relspec/convert.go +++ b/cmd/relspec/convert.go @@ -14,6 +14,7 @@ import ( "git.warky.dev/wdevs/relspecgo/pkg/readers/dbml" "git.warky.dev/wdevs/relspecgo/pkg/readers/dctx" "git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb" + "git.warky.dev/wdevs/relspecgo/pkg/readers/drizzle" "git.warky.dev/wdevs/relspecgo/pkg/readers/gorm" "git.warky.dev/wdevs/relspecgo/pkg/readers/json" "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql" @@ -25,6 +26,7 @@ import ( wdbml "git.warky.dev/wdevs/relspecgo/pkg/writers/dbml" wdctx "git.warky.dev/wdevs/relspecgo/pkg/writers/dctx" wdrawdb "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb" + wdrizzle "git.warky.dev/wdevs/relspecgo/pkg/writers/drizzle" wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm" wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json" wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql" @@ -60,6 +62,7 @@ Input formats: - yaml: YAML database schema - gorm: GORM model files (Go, file or directory) - bun: Bun model files (Go, file or directory) + - drizzle: Drizzle ORM schema files (TypeScript, file or directory) - prisma: Prisma schema files (.prisma) - typeorm: TypeORM entity files (TypeScript) - pgsql: PostgreSQL database (live connection) @@ -72,6 +75,7 @@ Output formats: - yaml: YAML database schema - gorm: GORM model files (Go) - bun: Bun model files (Go) + - drizzle: Drizzle ORM schema files (TypeScript) - prisma: Prisma schema files (.prisma) - typeorm: TypeORM entity files (TypeScript) - pgsql: PostgreSQL SQL schema @@ -132,11 +136,11 @@ Examples: } func init() { - convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, prisma, typeorm, pgsql)") + convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)") convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)") convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for database formats)") - convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, json, yaml, gorm, bun, prisma, typeorm, pgsql)") + convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)") convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)") convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)") convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)") @@ -257,6 +261,12 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa } reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + case "drizzle": + if filePath == "" { + return nil, fmt.Errorf("file path is required for Drizzle format") + } + reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) + case "prisma": if filePath == "" { return nil, fmt.Errorf("file path is required for Prisma format") @@ -317,6 +327,9 @@ func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaF } writer = wbun.NewWriter(writerOpts) + case "drizzle": + writer = wdrizzle.NewWriter(writerOpts) + case "pgsql", "postgres", "postgresql", "sql": writer = wpgsql.NewWriter(writerOpts) diff --git a/pkg/readers/drizzle/reader.go b/pkg/readers/drizzle/reader.go new file mode 100644 index 0000000..cf1e838 --- /dev/null +++ b/pkg/readers/drizzle/reader.go @@ -0,0 +1,627 @@ +package drizzle + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/readers" +) + +// min returns the minimum of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// Reader implements the readers.Reader interface for Drizzle schema format +type Reader struct { + options *readers.ReaderOptions +} + +// NewReader creates a new Drizzle reader with the given options +func NewReader(options *readers.ReaderOptions) *Reader { + return &Reader{ + options: options, + } +} + +// ReadDatabase reads and parses Drizzle schema input, returning a Database model +func (r *Reader) ReadDatabase() (*models.Database, error) { + if r.options.FilePath == "" { + return nil, fmt.Errorf("file path is required for Drizzle reader") + } + + // Check if it's a file or directory + info, err := os.Stat(r.options.FilePath) + if err != nil { + return nil, fmt.Errorf("failed to stat path: %w", err) + } + + if info.IsDir() { + // Read all .ts files in the directory + return r.readDirectory(r.options.FilePath) + } + + // Read single file + content, err := os.ReadFile(r.options.FilePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + return r.parseDrizzle(string(content)) +} + +// ReadSchema reads and parses Drizzle schema input, returning a Schema model +func (r *Reader) ReadSchema() (*models.Schema, error) { + db, err := r.ReadDatabase() + if err != nil { + return nil, err + } + + if len(db.Schemas) == 0 { + return nil, fmt.Errorf("no schemas found in Drizzle schema") + } + + // Return the first schema + return db.Schemas[0], nil +} + +// ReadTable reads and parses Drizzle schema input, returning a Table model +func (r *Reader) ReadTable() (*models.Table, error) { + schema, err := r.ReadSchema() + if err != nil { + return nil, err + } + + if len(schema.Tables) == 0 { + return nil, fmt.Errorf("no tables found in Drizzle schema") + } + + // Return the first table + return schema.Tables[0], nil +} + +// readDirectory reads all .ts files in a directory and parses them +func (r *Reader) readDirectory(dirPath string) (*models.Database, error) { + db := models.InitDatabase("database") + + if r.options.Metadata != nil { + if name, ok := r.options.Metadata["name"].(string); ok { + db.Name = name + } + } + + // Default schema for Drizzle + schema := models.InitSchema("public") + schema.Enums = make([]*models.Enum, 0) + + // Read all .ts files + files, err := filepath.Glob(filepath.Join(dirPath, "*.ts")) + if err != nil { + return nil, fmt.Errorf("failed to glob directory: %w", err) + } + + // Parse each file + for _, file := range files { + content, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", file, err) + } + + // Parse and merge into schema + fileDB, err := r.parseDrizzle(string(content)) + if err != nil { + return nil, fmt.Errorf("failed to parse file %s: %w", file, err) + } + + // Merge schemas + if len(fileDB.Schemas) > 0 { + fileSchema := fileDB.Schemas[0] + schema.Tables = append(schema.Tables, fileSchema.Tables...) + schema.Enums = append(schema.Enums, fileSchema.Enums...) + } + } + + db.Schemas = append(db.Schemas, schema) + return db, nil +} + +// parseDrizzle parses Drizzle schema content and returns a Database model +func (r *Reader) parseDrizzle(content string) (*models.Database, error) { + db := models.InitDatabase("database") + + if r.options.Metadata != nil { + if name, ok := r.options.Metadata["name"].(string); ok { + db.Name = name + } + } + + // Default schema for Drizzle (PostgreSQL) + schema := models.InitSchema("public") + schema.Enums = make([]*models.Enum, 0) + db.DatabaseType = models.PostgresqlDatabaseType + + scanner := bufio.NewScanner(strings.NewReader(content)) + + // Regex patterns + // Match: export const users = pgTable('users', { + pgTableRegex := regexp.MustCompile(`export\s+const\s+(\w+)\s*=\s*pgTable\s*\(\s*['"](\w+)['"]`) + // Match: export const userRole = pgEnum('UserRole', ['admin', 'user']); + pgEnumRegex := regexp.MustCompile(`export\s+const\s+(\w+)\s*=\s*pgEnum\s*\(\s*['"](\w+)['"]`) + + // State tracking + var currentTable *models.Table + var currentTableVarName string + var inTableBlock bool + var blockDepth int + var tableLines []string + + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + + // Skip empty lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "//") { + continue + } + + // Check for pgEnum definition + if matches := pgEnumRegex.FindStringSubmatch(trimmed); matches != nil { + enum := r.parsePgEnum(trimmed, matches) + if enum != nil { + schema.Enums = append(schema.Enums, enum) + } + continue + } + + // Check for pgTable definition + if matches := pgTableRegex.FindStringSubmatch(trimmed); matches != nil { + varName := matches[1] + tableName := matches[2] + + currentTableVarName = varName + currentTable = models.InitTable(tableName, "public") + inTableBlock = true + // Count braces in the first line + blockDepth = strings.Count(line, "{") - strings.Count(line, "}") + tableLines = []string{line} + continue + } + + // If we're in a table block, accumulate lines + if inTableBlock { + tableLines = append(tableLines, line) + + // Track brace depth + blockDepth += strings.Count(line, "{") + blockDepth -= strings.Count(line, "}") + + // Check if we've closed the table definition + if blockDepth < 0 || (blockDepth == 0 && strings.Contains(line, ");")) { + // Parse the complete table block + if currentTable != nil { + r.parseTableBlock(tableLines, currentTable, currentTableVarName) + schema.Tables = append(schema.Tables, currentTable) + currentTable = nil + } + inTableBlock = false + tableLines = nil + } + } + } + + db.Schemas = append(db.Schemas, schema) + return db, nil +} + +// parsePgEnum parses a pgEnum definition +func (r *Reader) parsePgEnum(line string, matches []string) *models.Enum { + // matches[1] = variable name + // matches[2] = enum name + + enumName := matches[2] + + // Extract values from the array + // Example: pgEnum('UserRole', ['admin', 'user', 'guest']) + valuesRegex := regexp.MustCompile(`\[(.*?)\]`) + valuesMatch := valuesRegex.FindStringSubmatch(line) + if valuesMatch == nil { + return nil + } + + valuesStr := valuesMatch[1] + // Split by comma and clean up + valueParts := strings.Split(valuesStr, ",") + values := make([]string, 0) + for _, part := range valueParts { + // Remove quotes and whitespace + cleaned := strings.TrimSpace(part) + cleaned = strings.Trim(cleaned, "'\"") + if cleaned != "" { + values = append(values, cleaned) + } + } + + return &models.Enum{ + Name: enumName, + Values: values, + Schema: "public", + } +} + +// parseTableBlock parses a complete pgTable definition block +func (r *Reader) parseTableBlock(lines []string, table *models.Table, tableVarName string) { + // Join all lines into a single string for easier parsing + fullText := strings.Join(lines, "\n") + + // Extract the columns block and index callback separately + // The structure is: pgTable('name', { columns }, (table) => [indexes]) + + // Find the main object block (columns) + columnsStart := strings.Index(fullText, "{") + if columnsStart == -1 { + return + } + + // Find matching closing brace for columns + depth := 0 + columnsEnd := -1 + for i := columnsStart; i < len(fullText); i++ { + if fullText[i] == '{' { + depth++ + } else if fullText[i] == '}' { + depth-- + if depth == 0 { + columnsEnd = i + break + } + } + } + + if columnsEnd == -1 { + return + } + + columnsBlock := fullText[columnsStart+1 : columnsEnd] + + // Parse columns + r.parseColumnsBlock(columnsBlock, table, tableVarName) + + // Check for index callback: , (table) => [ or , ({ col1, col2 }) => [ + // Match: }, followed by arrow function with any parameters + // Use (?s) flag to make . match newlines + indexCallbackRegex := regexp.MustCompile(`(?s)}\s*,\s*\(.*?\)\s*=>\s*\[`) + if indexCallbackRegex.MatchString(fullText[columnsEnd:]) { + // Find the index array + indexStart := strings.Index(fullText[columnsEnd:], "[") + if indexStart != -1 { + indexStart += columnsEnd + indexDepth := 0 + indexEnd := -1 + for i := indexStart; i < len(fullText); i++ { + if fullText[i] == '[' { + indexDepth++ + } else if fullText[i] == ']' { + indexDepth-- + if indexDepth == 0 { + indexEnd = i + break + } + } + } + + if indexEnd != -1 { + indexBlock := fullText[indexStart+1 : indexEnd] + r.parseIndexBlock(indexBlock, table, tableVarName) + } + } + } +} + +// parseColumnsBlock parses the columns block of a table +func (r *Reader) parseColumnsBlock(block string, table *models.Table, tableVarName string) { + // Split by lines and parse each column definition + lines := strings.Split(block, "\n") + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "//") { + continue + } + + // Match: fieldName: columnType('columnName').modifier().modifier(), + // Example: id: integer('id').primaryKey(), + columnRegex := regexp.MustCompile(`(\w+):\s*(\w+)\s*\(`) + matches := columnRegex.FindStringSubmatch(trimmed) + if matches == nil { + continue + } + + fieldName := matches[1] + columnType := matches[2] + + // Parse the column definition + col := r.parseColumnDefinition(trimmed, fieldName, columnType, table) + if col != nil { + table.Columns[col.Name] = col + } + } +} + +// parseColumnDefinition parses a single column definition line +func (r *Reader) parseColumnDefinition(line, fieldName, drizzleType string, table *models.Table) *models.Column { + // Check for enum column syntax: pgEnum('EnumName')('column_name') + enumRegex := regexp.MustCompile(`pgEnum\s*\(['"](\w+)['"]\)\s*\(['"](\w+)['"]\)`) + if enumMatch := enumRegex.FindStringSubmatch(line); enumMatch != nil { + enumName := enumMatch[1] + columnName := enumMatch[2] + + column := models.InitColumn(columnName, table.Name, table.Schema) + column.Type = enumName + column.NotNull = false + + // Parse modifiers + r.parseColumnModifiers(line, column, table) + return column + } + + // Extract column name from the first argument + // Example: integer('id') + nameRegex := regexp.MustCompile(`\w+\s*\(['"](\w+)['"]\)`) + nameMatch := nameRegex.FindStringSubmatch(line) + if nameMatch == nil { + return nil + } + + columnName := nameMatch[1] + column := models.InitColumn(columnName, table.Name, table.Schema) + + // Map Drizzle type to SQL type + column.Type = r.drizzleTypeToSQL(drizzleType) + + // Default: columns are nullable unless specified + column.NotNull = false + + // Parse modifiers + r.parseColumnModifiers(line, column, table) + + return column +} + +// drizzleTypeToSQL converts Drizzle column types to SQL types +func (r *Reader) drizzleTypeToSQL(drizzleType string) string { + typeMap := map[string]string{ + // Integer types + "integer": "integer", + "bigint": "bigint", + "smallint": "smallint", + + // Serial types + "serial": "serial", + "bigserial": "bigserial", + "smallserial": "smallserial", + + // Numeric types + "numeric": "numeric", + "real": "real", + "doublePrecision": "double precision", + + // Character types + "text": "text", + "varchar": "varchar", + "char": "char", + + // Boolean + "boolean": "boolean", + + // Binary + "bytea": "bytea", + + // JSON + "json": "json", + "jsonb": "jsonb", + + // Date/Time + "time": "time", + "timestamp": "timestamp", + "date": "date", + "interval": "interval", + + // UUID + "uuid": "uuid", + + // Geometric + "point": "point", + "line": "line", + } + + if sqlType, ok := typeMap[drizzleType]; ok { + return sqlType + } + + // If not found, might be an enum - return as-is + return drizzleType +} + +// parseColumnModifiers parses column modifiers like .primaryKey(), .notNull(), etc. +func (r *Reader) parseColumnModifiers(line string, column *models.Column, table *models.Table) { + // Check for .primaryKey() + if strings.Contains(line, ".primaryKey()") { + column.IsPrimaryKey = true + column.NotNull = true + } + + // Check for .notNull() + if strings.Contains(line, ".notNull()") { + column.NotNull = true + } + + // Check for .unique() + if strings.Contains(line, ".unique()") { + uniqueConstraint := models.InitConstraint( + fmt.Sprintf("uq_%s", column.Name), + models.UniqueConstraint, + ) + uniqueConstraint.Schema = table.Schema + uniqueConstraint.Table = table.Name + uniqueConstraint.Columns = []string{column.Name} + table.Constraints[uniqueConstraint.Name] = uniqueConstraint + } + + // Check for .default(...) + // Need to handle nested backticks and parentheses in SQL expressions + defaultIdx := strings.Index(line, ".default(") + if defaultIdx != -1 { + start := defaultIdx + len(".default(") + depth := 1 + inBacktick := false + i := start + + for i < len(line) && depth > 0 { + ch := line[i] + if ch == '`' { + inBacktick = !inBacktick + } else if !inBacktick { + switch ch { + case '(': + depth++ + case ')': + depth-- + } + } + i++ + } + + if depth == 0 { + defaultValue := strings.TrimSpace(line[start : i-1]) + r.parseDefaultValue(defaultValue, column) + } + } + + // Check for .generatedAlwaysAsIdentity() + if strings.Contains(line, ".generatedAlwaysAsIdentity()") { + column.AutoIncrement = true + } + + // Check for .references(() => otherTable.column) + referencesRegex := regexp.MustCompile(`\.references\(\(\)\s*=>\s*(\w+)\.(\w+)\)`) + if matches := referencesRegex.FindStringSubmatch(line); matches != nil { + refTableVar := matches[1] + refColumn := matches[2] + + // Create FK constraint + constraintName := fmt.Sprintf("fk_%s_%s", table.Name, column.Name) + constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint) + constraint.Schema = table.Schema + constraint.Table = table.Name + constraint.Columns = []string{column.Name} + constraint.ReferencedSchema = table.Schema // Assume same schema + constraint.ReferencedTable = r.varNameToTableName(refTableVar) + constraint.ReferencedColumns = []string{refColumn} + + table.Constraints[constraint.Name] = constraint + } +} + +// parseDefaultValue parses a default value expression +func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) { + defaultExpr = strings.TrimSpace(defaultExpr) + + // Handle SQL expressions like sql`now()` + sqlRegex := regexp.MustCompile("sql`([^`]+)`") + if match := sqlRegex.FindStringSubmatch(defaultExpr); match != nil { + column.Default = match[1] + return + } + + // Handle boolean values + if defaultExpr == "true" { + column.Default = true + return + } + if defaultExpr == "false" { + column.Default = false + return + } + + // Handle string literals + if strings.HasPrefix(defaultExpr, "'") && strings.HasSuffix(defaultExpr, "'") { + column.Default = defaultExpr[1 : len(defaultExpr)-1] + return + } + if strings.HasPrefix(defaultExpr, "\"") && strings.HasSuffix(defaultExpr, "\"") { + column.Default = defaultExpr[1 : len(defaultExpr)-1] + return + } + + // Try to parse as number + column.Default = defaultExpr +} + +// parseIndexBlock parses the index callback block +func (r *Reader) parseIndexBlock(block string, table *models.Table, tableVarName string) { + // Split by lines + lines := strings.Split(block, "\n") + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "//") { + continue + } + + // Match: index('index_name').on(table.col1, table.col2) + // or: uniqueIndex('index_name').on(table.col1, table.col2) + indexRegex := regexp.MustCompile(`(uniqueIndex|index)\s*\(['"](\w+)['"]\)\s*\.on\s*\((.*?)\)`) + matches := indexRegex.FindStringSubmatch(trimmed) + if matches == nil { + continue + } + + indexType := matches[1] + indexName := matches[2] + columnsStr := matches[3] + + // Parse column list + columnParts := strings.Split(columnsStr, ",") + columns := make([]string, 0) + for _, part := range columnParts { + // Remove table prefix: table.column -> column + cleaned := strings.TrimSpace(part) + if strings.Contains(cleaned, ".") { + parts := strings.Split(cleaned, ".") + cleaned = parts[len(parts)-1] + } + columns = append(columns, cleaned) + } + + if indexType == "uniqueIndex" { + // Create unique constraint + constraint := models.InitConstraint(indexName, models.UniqueConstraint) + constraint.Schema = table.Schema + constraint.Table = table.Name + constraint.Columns = columns + table.Constraints[constraint.Name] = constraint + } else { + // Create index + index := models.InitIndex(indexName, table.Name, table.Schema) + index.Columns = columns + index.Unique = false + table.Indexes[index.Name] = index + } + } +} + +// varNameToTableName converts a variable name to a table name +// For now, just return as-is (could add inflection later) +func (r *Reader) varNameToTableName(varName string) string { + // TODO: Could add conversion logic here if needed + // For now, assume variable name matches table name + return varName +} diff --git a/pkg/writers/drizzle/template_data.go b/pkg/writers/drizzle/template_data.go new file mode 100644 index 0000000..0a4584a --- /dev/null +++ b/pkg/writers/drizzle/template_data.go @@ -0,0 +1,221 @@ +package drizzle + +import ( + "sort" + + "git.warky.dev/wdevs/relspecgo/pkg/models" +) + +// TemplateData represents the data passed to the template for code generation +type TemplateData struct { + Imports []string + Enums []*EnumData + Tables []*TableData +} + +// EnumData represents an enum in the schema +type EnumData struct { + Name string // Enum name (PascalCase) + VarName string // Variable name for the enum (camelCase) + Values []string // Enum values + ValuesStr string // Comma-separated quoted values for pgEnum() + TypeUnion string // TypeScript union type (e.g., "'admin' | 'user' | 'guest'") + SchemaName string // Schema name +} + +// TableData represents a table in the template +type TableData struct { + Name string // Table variable name (camelCase, e.g., users) + TableName string // Actual database table name (e.g., users) + TypeName string // TypeScript type name (PascalCase, e.g., Users) + Columns []*ColumnData // Column definitions + Indexes []*IndexData // Index definitions + Comment string // Table comment + SchemaName string // Schema name + NeedsSQLTag bool // Whether we need to import 'sql' from drizzle-orm + IndexColumnFields []string // Column field names used in indexes (for destructuring) +} + +// ColumnData represents a column in a table +type ColumnData struct { + Name string // Column name in database + FieldName string // Field name in TypeScript (camelCase) + DrizzleChain string // Complete Drizzle column chain (e.g., "integer('id').primaryKey()") + TypeScriptType string // TypeScript type for interface (e.g., "string", "number | null") + IsForeignKey bool // Whether this is a foreign key + ReferencesLine string // The .references() line if FK + Comment string // Column comment +} + +// IndexData represents an index definition +type IndexData struct { + Name string // Index name + Columns []string // Column names + IsUnique bool // Whether it's a unique index + Definition string // Complete index definition line +} + +// NewTemplateData creates a new TemplateData +func NewTemplateData() *TemplateData { + return &TemplateData{ + Imports: make([]string, 0), + Enums: make([]*EnumData, 0), + Tables: make([]*TableData, 0), + } +} + +// AddImport adds an import to the template data (deduplicates automatically) +func (td *TemplateData) AddImport(importLine string) { + // Check if already exists + for _, imp := range td.Imports { + if imp == importLine { + return + } + } + td.Imports = append(td.Imports, importLine) +} + +// AddEnum adds an enum to the template data +func (td *TemplateData) AddEnum(enum *EnumData) { + td.Enums = append(td.Enums, enum) +} + +// AddTable adds a table to the template data +func (td *TemplateData) AddTable(table *TableData) { + td.Tables = append(td.Tables, table) +} + +// FinalizeImports sorts imports +func (td *TemplateData) FinalizeImports() { + sort.Strings(td.Imports) +} + +// NewEnumData creates EnumData from a models.Enum +func NewEnumData(enum *models.Enum, tm *TypeMapper) *EnumData { + // Keep enum name as-is (it should already be PascalCase from the source) + enumName := enum.Name + // Variable name is camelCase version + varName := tm.ToCamelCase(enum.Name) + + // Format values as comma-separated quoted strings for pgEnum() + quotedValues := make([]string, len(enum.Values)) + for i, v := range enum.Values { + quotedValues[i] = "'" + v + "'" + } + valuesStr := "" + for i, qv := range quotedValues { + if i > 0 { + valuesStr += ", " + } + valuesStr += qv + } + + // Build TypeScript union type (e.g., "'admin' | 'user' | 'guest'") + typeUnion := "" + for i, qv := range quotedValues { + if i > 0 { + typeUnion += " | " + } + typeUnion += qv + } + + return &EnumData{ + Name: enumName, + VarName: varName, + Values: enum.Values, + ValuesStr: valuesStr, + TypeUnion: typeUnion, + SchemaName: enum.Schema, + } +} + +// NewTableData creates TableData from a models.Table +func NewTableData(table *models.Table, tm *TypeMapper) *TableData { + tableName := tm.ToCamelCase(table.Name) + typeName := tm.ToPascalCase(table.Name) + + return &TableData{ + Name: tableName, + TableName: table.Name, + TypeName: typeName, + Columns: make([]*ColumnData, 0), + Indexes: make([]*IndexData, 0), + Comment: formatComment(table.Description, table.Comment), + SchemaName: table.Schema, + } +} + +// AddColumn adds a column to the table data +func (td *TableData) AddColumn(col *ColumnData) { + td.Columns = append(td.Columns, col) +} + +// AddIndex adds an index to the table data +func (td *TableData) AddIndex(idx *IndexData) { + td.Indexes = append(td.Indexes, idx) +} + +// NewColumnData creates ColumnData from a models.Column +func NewColumnData(col *models.Column, table *models.Table, tm *TypeMapper, isEnum bool) *ColumnData { + fieldName := tm.ToCamelCase(col.Name) + drizzleChain := tm.BuildColumnChain(col, table, isEnum) + + return &ColumnData{ + Name: col.Name, + FieldName: fieldName, + DrizzleChain: drizzleChain, + Comment: formatComment(col.Description, col.Comment), + } +} + +// NewIndexData creates IndexData from a models.Index +func NewIndexData(index *models.Index, tableVar string, tm *TypeMapper) *IndexData { + indexName := tm.ToCamelCase(index.Name) + "Idx" + + // Build column references as field names (will be used with destructuring) + colRefs := make([]string, len(index.Columns)) + for i, colName := range index.Columns { + // Use just the field name for destructured parameters + colRefs[i] = tm.ToCamelCase(colName) + } + + // Build the complete definition + // Example: index('email_idx').on(email) + // or: uniqueIndex('unique_email_idx').on(email) + definition := "" + if index.Unique { + definition = "uniqueIndex('" + index.Name + "').on(" + joinStrings(colRefs, ", ") + ")" + } else { + definition = "index('" + index.Name + "').on(" + joinStrings(colRefs, ", ") + ")" + } + + return &IndexData{ + Name: indexName, + Columns: index.Columns, + IsUnique: index.Unique, + Definition: definition, + } +} + +// formatComment combines description and comment into a single comment string +func formatComment(description, comment string) string { + if description != "" && comment != "" { + return description + " - " + comment + } + if description != "" { + return description + } + return comment +} + +// joinStrings joins a slice of strings with a separator +func joinStrings(strs []string, sep string) string { + result := "" + for i, s := range strs { + if i > 0 { + result += sep + } + result += s + } + return result +} diff --git a/pkg/writers/drizzle/templates.go b/pkg/writers/drizzle/templates.go new file mode 100644 index 0000000..6cc0c3a --- /dev/null +++ b/pkg/writers/drizzle/templates.go @@ -0,0 +1,64 @@ +package drizzle + +import ( + "bytes" + "text/template" +) + +// schemaTemplate defines the template for generating Drizzle schemas +const schemaTemplate = `// Code generated by relspecgo. DO NOT EDIT. +{{range .Imports}}{{.}} +{{end}} +{{if .Enums}} +// Enums +{{range .Enums}}export const {{.VarName}} = pgEnum('{{.Name}}', [{{.ValuesStr}}]); +export type {{.Name}} = {{.TypeUnion}}; +{{end}} +{{end}} +{{range .Tables}}// Table: {{.TableName}}{{if .Comment}} - {{.Comment}}{{end}} +export interface {{.TypeName}} { +{{- range $i, $col := .Columns}} + {{$col.FieldName}}: {{$col.TypeScriptType}};{{if $col.Comment}} // {{$col.Comment}}{{end}} +{{- end}} +} + +export const {{.Name}} = pgTable('{{.TableName}}', { +{{- range $i, $col := .Columns}} + {{$col.FieldName}}: {{$col.DrizzleChain}},{{if $col.Comment}} // {{$col.Comment}}{{end}} +{{- end}} +}{{if .Indexes}}{{if .IndexColumnFields}}, ({ {{range $i, $field := .IndexColumnFields}}{{if $i}}, {{end}}{{$field}}{{end}} }) => [{{else}}, (table) => [{{end}} +{{- range $i, $idx := .Indexes}} + {{$idx.Definition}}, +{{- end}} +]{{end}}); + +export type New{{.TypeName}} = typeof {{.Name}}.$inferInsert; +{{end}}` + +// Templates holds the parsed templates +type Templates struct { + schemaTmpl *template.Template +} + +// NewTemplates creates and parses the templates +func NewTemplates() (*Templates, error) { + schemaTmpl, err := template.New("schema").Parse(schemaTemplate) + if err != nil { + return nil, err + } + + return &Templates{ + schemaTmpl: schemaTmpl, + }, nil +} + +// GenerateCode executes the template with the given data +func (t *Templates) GenerateCode(data *TemplateData) (string, error) { + var buf bytes.Buffer + err := t.schemaTmpl.Execute(&buf, data) + if err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/pkg/writers/drizzle/type_mapper.go b/pkg/writers/drizzle/type_mapper.go new file mode 100644 index 0000000..97998bd --- /dev/null +++ b/pkg/writers/drizzle/type_mapper.go @@ -0,0 +1,318 @@ +package drizzle + +import ( + "fmt" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" +) + +// TypeMapper handles SQL to Drizzle type conversions +type TypeMapper struct{} + +// NewTypeMapper creates a new TypeMapper instance +func NewTypeMapper() *TypeMapper { + return &TypeMapper{} +} + +// SQLTypeToDrizzle converts SQL types to Drizzle column type functions +// Returns the Drizzle column constructor (e.g., "integer", "varchar", "text") +func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string { + sqlTypeLower := strings.ToLower(sqlType) + + // PostgreSQL type mapping to Drizzle + typeMap := map[string]string{ + // Integer types + "integer": "integer", + "int": "integer", + "int4": "integer", + "smallint": "smallint", + "int2": "smallint", + "bigint": "bigint", + "int8": "bigint", + + // Serial types + "serial": "serial", + "serial4": "serial", + "smallserial": "smallserial", + "serial2": "smallserial", + "bigserial": "bigserial", + "serial8": "bigserial", + + // Numeric types + "numeric": "numeric", + "decimal": "numeric", + "real": "real", + "float4": "real", + "double precision": "doublePrecision", + "float": "doublePrecision", + "float8": "doublePrecision", + + // Character types + "text": "text", + "varchar": "varchar", + "character varying": "varchar", + "char": "char", + "character": "char", + + // Boolean + "boolean": "boolean", + "bool": "boolean", + + // Binary + "bytea": "bytea", + + // JSON types + "json": "json", + "jsonb": "jsonb", + + // Date/Time types + "time": "time", + "timetz": "time", + "timestamp": "timestamp", + "timestamptz": "timestamp", + "date": "date", + "interval": "interval", + + // UUID + "uuid": "uuid", + + // Geometric types + "point": "point", + "line": "line", + } + + // Check for exact match first + if drizzleType, ok := typeMap[sqlTypeLower]; ok { + return drizzleType + } + + // Check for partial matches (e.g., "varchar(255)" -> "varchar") + for sqlPattern, drizzleType := range typeMap { + if strings.HasPrefix(sqlTypeLower, sqlPattern) { + return drizzleType + } + } + + // Default to text for unknown types + return "text" +} + +// BuildColumnChain builds the complete column definition chain for Drizzle +// Example: integer('id').primaryKey().notNull() +func (tm *TypeMapper) BuildColumnChain(col *models.Column, table *models.Table, isEnum bool) string { + var parts []string + + // Determine Drizzle column type + var drizzleType string + if isEnum { + // For enum types, use the type name directly + drizzleType = fmt.Sprintf("pgEnum('%s')", col.Type) + } else { + drizzleType = tm.SQLTypeToDrizzle(col.Type) + } + + // Start with column type and name + // Note: column name is passed as first argument to the column constructor + base := fmt.Sprintf("%s('%s')", drizzleType, col.Name) + parts = append(parts, base) + + // Add column modifiers in order + modifiers := tm.buildColumnModifiers(col, table) + if len(modifiers) > 0 { + parts = append(parts, modifiers...) + } + + return strings.Join(parts, ".") +} + +// buildColumnModifiers builds an array of method calls for column modifiers +func (tm *TypeMapper) buildColumnModifiers(col *models.Column, table *models.Table) []string { + var modifiers []string + + // Primary key + if col.IsPrimaryKey { + modifiers = append(modifiers, "primaryKey()") + } + + // Not null constraint + if col.NotNull && !col.IsPrimaryKey { + modifiers = append(modifiers, "notNull()") + } + + // Unique constraint (check if there's a single-column unique constraint) + if tm.hasUniqueConstraint(col.Name, table) { + modifiers = append(modifiers, "unique()") + } + + // Default value + if col.AutoIncrement { + // For auto-increment, use generatedAlwaysAsIdentity() + modifiers = append(modifiers, "generatedAlwaysAsIdentity()") + } else if col.Default != nil { + defaultValue := tm.formatDefaultValue(col.Default) + if defaultValue != "" { + modifiers = append(modifiers, fmt.Sprintf("default(%s)", defaultValue)) + } + } + + return modifiers +} + +// formatDefaultValue formats a default value for Drizzle +func (tm *TypeMapper) formatDefaultValue(defaultValue any) string { + switch v := defaultValue.(type) { + case string: + if v == "now()" || v == "CURRENT_TIMESTAMP" { + return "sql`now()`" + } else if v == "gen_random_uuid()" || strings.Contains(strings.ToLower(v), "uuid") { + return "sql`gen_random_uuid()`" + } else { + // Try to parse as number first + // Check if it's a numeric string that should be a number + if isNumericString(v) { + return v + } + // String literal + return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "\\'")) + } + case bool: + if v { + return "true" + } + return "false" + case int, int64, int32, int16, int8: + return fmt.Sprintf("%v", v) + case float32, float64: + return fmt.Sprintf("%v", v) + default: + return fmt.Sprintf("%v", v) + } +} + +// isNumericString checks if a string represents a number +func isNumericString(s string) bool { + if s == "" { + return false + } + // Simple check for numeric strings + for i, c := range s { + if i == 0 && c == '-' { + continue // Allow negative sign at start + } + if c < '0' || c > '9' { + if c != '.' { + return false + } + } + } + return true +} + +// hasUniqueConstraint checks if a column has a unique constraint +func (tm *TypeMapper) hasUniqueConstraint(colName string, table *models.Table) bool { + for _, constraint := range table.Constraints { + if constraint.Type == models.UniqueConstraint && + len(constraint.Columns) == 1 && + constraint.Columns[0] == colName { + return true + } + } + return false +} + +// BuildReferencesChain builds the .references() chain for foreign key columns +func (tm *TypeMapper) BuildReferencesChain(fk *models.Constraint, referencedTable string) string { + // Example: .references(() => users.id) + if len(fk.ReferencedColumns) > 0 { + // Use the referenced table variable name (camelCase) + refTableVar := tm.ToCamelCase(referencedTable) + refColumn := fk.ReferencedColumns[0] + return fmt.Sprintf("references(() => %s.%s)", refTableVar, refColumn) + } + return "" +} + +// ToCamelCase converts snake_case or PascalCase to camelCase +func (tm *TypeMapper) ToCamelCase(s string) string { + if s == "" { + return s + } + + // Check if it's snake_case + if strings.Contains(s, "_") { + parts := strings.Split(s, "_") + if len(parts) == 0 { + return s + } + + // First part stays lowercase + result := strings.ToLower(parts[0]) + + // Capitalize first letter of remaining parts + for i := 1; i < len(parts); i++ { + if len(parts[i]) > 0 { + result += strings.ToUpper(parts[i][:1]) + strings.ToLower(parts[i][1:]) + } + } + + return result + } + + // Otherwise, assume it's PascalCase - just lowercase the first letter + return strings.ToLower(s[:1]) + s[1:] +} + +// ToPascalCase converts snake_case to PascalCase +func (tm *TypeMapper) ToPascalCase(s string) string { + parts := strings.Split(s, "_") + var result string + + for _, part := range parts { + if len(part) > 0 { + result += strings.ToUpper(part[:1]) + strings.ToLower(part[1:]) + } + } + + return result +} + +// DrizzleTypeToTypeScript converts Drizzle column types to TypeScript types +func (tm *TypeMapper) DrizzleTypeToTypeScript(drizzleType string, isEnum bool, enumName string) string { + if isEnum { + return enumName + } + + typeMap := map[string]string{ + "integer": "number", + "bigint": "number", + "smallint": "number", + "serial": "number", + "bigserial": "number", + "smallserial": "number", + "numeric": "number", + "real": "number", + "doublePrecision": "number", + "text": "string", + "varchar": "string", + "char": "string", + "boolean": "boolean", + "bytea": "Buffer", + "json": "any", + "jsonb": "any", + "timestamp": "Date", + "date": "Date", + "time": "Date", + "interval": "string", + "uuid": "string", + "point": "{ x: number; y: number }", + "line": "{ a: number; b: number; c: number }", + } + + if tsType, ok := typeMap[drizzleType]; ok { + return tsType + } + + // Default to any for unknown types + return "any" +} diff --git a/pkg/writers/drizzle/writer.go b/pkg/writers/drizzle/writer.go new file mode 100644 index 0000000..03d95a8 --- /dev/null +++ b/pkg/writers/drizzle/writer.go @@ -0,0 +1,543 @@ +package drizzle + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +// Writer implements the writers.Writer interface for Drizzle ORM +type Writer struct { + options *writers.WriterOptions + typeMapper *TypeMapper + templates *Templates +} + +// NewWriter creates a new Drizzle writer with the given options +func NewWriter(options *writers.WriterOptions) *Writer { + w := &Writer{ + options: options, + typeMapper: NewTypeMapper(), + } + + // Initialize templates + tmpl, err := NewTemplates() + if err != nil { + // Should not happen with embedded templates + panic(fmt.Sprintf("failed to initialize templates: %v", err)) + } + w.templates = tmpl + + return w +} + +// WriteDatabase writes a complete database as Drizzle schema +func (w *Writer) WriteDatabase(db *models.Database) error { + // Check if multi-file mode is enabled + multiFile := w.shouldUseMultiFile() + + if multiFile { + return w.writeMultiFile(db) + } + + return w.writeSingleFile(db) +} + +// WriteSchema writes a schema as Drizzle schema +func (w *Writer) WriteSchema(schema *models.Schema) error { + // Create a temporary database with just this schema + db := models.InitDatabase(schema.Name) + db.Schemas = []*models.Schema{schema} + + return w.WriteDatabase(db) +} + +// WriteTable writes a single table as a Drizzle schema +func (w *Writer) WriteTable(table *models.Table) error { + // Create a temporary schema and database + schema := models.InitSchema(table.Schema) + schema.Tables = []*models.Table{table} + + db := models.InitDatabase(schema.Name) + db.Schemas = []*models.Schema{schema} + + return w.WriteDatabase(db) +} + +// writeSingleFile writes all tables to a single file +func (w *Writer) writeSingleFile(db *models.Database) error { + templateData := NewTemplateData() + + // Build enum map for quick lookup + enumMap := w.buildEnumMap(db) + + // Process all schemas + for _, schema := range db.Schemas { + // Add enums + for _, enum := range schema.Enums { + enumData := NewEnumData(enum, w.typeMapper) + templateData.AddEnum(enumData) + } + + // Add tables + for _, table := range schema.Tables { + tableData := w.buildTableData(table, schema, db, enumMap) + templateData.AddTable(tableData) + } + } + + // Add imports + w.addImports(templateData, db) + + // Finalize imports + templateData.FinalizeImports() + + // Generate code + code, err := w.templates.GenerateCode(templateData) + if err != nil { + return fmt.Errorf("failed to generate code: %w", err) + } + + // Write output + return w.writeOutput(code) +} + +// writeMultiFile writes each table to a separate file +func (w *Writer) writeMultiFile(db *models.Database) error { + // Ensure output path is a directory + if w.options.OutputPath == "" { + return fmt.Errorf("output path is required for multi-file mode") + } + + // Create output directory if it doesn't exist + if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + // Build enum map for quick lookup + enumMap := w.buildEnumMap(db) + + // Process all schemas + for _, schema := range db.Schemas { + // Write enums file if there are any + if len(schema.Enums) > 0 { + if err := w.writeEnumsFile(schema); err != nil { + return err + } + } + + // Write each table to a separate file + for _, table := range schema.Tables { + if err := w.writeTableFile(table, schema, db, enumMap); err != nil { + return err + } + } + } + + return nil +} + +// writeEnumsFile writes all enums to a separate file +func (w *Writer) writeEnumsFile(schema *models.Schema) error { + templateData := NewTemplateData() + + // Add enums + for _, enum := range schema.Enums { + enumData := NewEnumData(enum, w.typeMapper) + templateData.AddEnum(enumData) + } + + // Add imports for enums + templateData.AddImport("import { pgEnum } from 'drizzle-orm/pg-core';") + + // Generate code + code, err := w.templates.GenerateCode(templateData) + if err != nil { + return fmt.Errorf("failed to generate enums code: %w", err) + } + + // Write to enums.ts file + filename := filepath.Join(w.options.OutputPath, "enums.ts") + return os.WriteFile(filename, []byte(code), 0644) +} + +// writeTableFile writes a single table to its own file +func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) error { + templateData := NewTemplateData() + + // Build table data + tableData := w.buildTableData(table, schema, db, enumMap) + templateData.AddTable(tableData) + + // Add imports + w.addImports(templateData, db) + + // If there are enums, add import from enums file + if len(schema.Enums) > 0 && w.tableUsesEnum(table, enumMap) { + // Import enum definitions from enums.ts + enumNames := w.getTableEnumNames(table, schema, enumMap) + if len(enumNames) > 0 { + importLine := fmt.Sprintf("import { %s } from './enums';", strings.Join(enumNames, ", ")) + templateData.AddImport(importLine) + } + } + + // Finalize imports + templateData.FinalizeImports() + + // Generate code + code, err := w.templates.GenerateCode(templateData) + if err != nil { + return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err) + } + + // Generate filename: {tableName}.ts + filename := filepath.Join(w.options.OutputPath, table.Name+".ts") + return os.WriteFile(filename, []byte(code), 0644) +} + +// buildTableData builds TableData from a models.Table +func (w *Writer) buildTableData(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) *TableData { + tableData := NewTableData(table, w.typeMapper) + + // Add columns + for _, colName := range w.getSortedColumnNames(table) { + col := table.Columns[colName] + + // Check if this column uses an enum + isEnum := enumMap[col.Type] + + columnData := NewColumnData(col, table, w.typeMapper, isEnum) + + // Set TypeScript type + drizzleType := w.typeMapper.SQLTypeToDrizzle(col.Type) + enumName := "" + if isEnum { + // For enums, use the enum type name + enumName = col.Type + } + baseType := w.typeMapper.DrizzleTypeToTypeScript(drizzleType, isEnum, enumName) + + // Add null union if column is nullable + if !col.NotNull && !col.IsPrimaryKey { + columnData.TypeScriptType = baseType + " | null" + } else { + columnData.TypeScriptType = baseType + } + + // Check if this column is a foreign key + if fk := w.getForeignKeyForColumn(col.Name, table); fk != nil { + columnData.IsForeignKey = true + refTableName := fk.ReferencedTable + refChain := w.typeMapper.BuildReferencesChain(fk, refTableName) + if refChain != "" { + columnData.ReferencesLine = "." + refChain + // Append to the drizzle chain + columnData.DrizzleChain += columnData.ReferencesLine + } + } + + tableData.AddColumn(columnData) + } + + // Collect all column field names that are used in indexes + indexColumnFields := make(map[string]bool) + + // Add indexes (excluding single-column unique indexes, which are handled inline) + for _, index := range table.Indexes { + // Skip single-column unique indexes (handled by .unique() modifier) + if index.Unique && len(index.Columns) == 1 { + continue + } + + // Track which columns are used in indexes + for _, colName := range index.Columns { + // Find the field name for this column + if col, exists := table.Columns[colName]; exists { + fieldName := w.typeMapper.ToCamelCase(col.Name) + indexColumnFields[fieldName] = true + } + } + + indexData := NewIndexData(index, tableData.Name, w.typeMapper) + tableData.AddIndex(indexData) + } + + // Add multi-column unique constraints as unique indexes + for _, constraint := range table.Constraints { + if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 { + // Create a unique index for this constraint + indexData := &IndexData{ + Name: w.typeMapper.ToCamelCase(constraint.Name) + "Idx", + Columns: constraint.Columns, + IsUnique: true, + } + + // Track which columns are used in indexes + for _, colName := range constraint.Columns { + if col, exists := table.Columns[colName]; exists { + fieldName := w.typeMapper.ToCamelCase(col.Name) + indexColumnFields[fieldName] = true + } + } + + // Build column references as field names (for destructuring) + colRefs := make([]string, len(constraint.Columns)) + for i, colName := range constraint.Columns { + if col, exists := table.Columns[colName]; exists { + colRefs[i] = w.typeMapper.ToCamelCase(col.Name) + } else { + colRefs[i] = w.typeMapper.ToCamelCase(colName) + } + } + + indexData.Definition = "uniqueIndex('" + constraint.Name + "').on(" + joinStrings(colRefs, ", ") + ")" + tableData.AddIndex(indexData) + } + } + + // Convert index column fields map to sorted slice + if len(indexColumnFields) > 0 { + fields := make([]string, 0, len(indexColumnFields)) + for field := range indexColumnFields { + fields = append(fields, field) + } + // Sort for consistent output + sortStrings(fields) + tableData.IndexColumnFields = fields + } + + return tableData +} + +// sortStrings sorts a slice of strings in place +func sortStrings(strs []string) { + for i := 0; i < len(strs); i++ { + for j := i + 1; j < len(strs); j++ { + if strs[i] > strs[j] { + strs[i], strs[j] = strs[j], strs[i] + } + } + } +} + +// addImports adds the necessary imports to the template data +func (w *Writer) addImports(templateData *TemplateData, db *models.Database) { + // Determine which Drizzle imports we need + needsPgTable := len(templateData.Tables) > 0 + needsPgEnum := len(templateData.Enums) > 0 + needsIndex := false + needsUniqueIndex := false + needsSQL := false + + // Check what we need based on tables + for _, table := range templateData.Tables { + for _, index := range table.Indexes { + if index.IsUnique { + needsUniqueIndex = true + } else { + needsIndex = true + } + } + + // Check if any column uses SQL default values + for _, col := range table.Columns { + if strings.Contains(col.DrizzleChain, "sql`") { + needsSQL = true + } + } + } + + // Build the import statement + imports := make([]string, 0) + + if needsPgTable { + imports = append(imports, "pgTable") + } + if needsPgEnum { + imports = append(imports, "pgEnum") + } + + // Add column types - for now, add common ones + // TODO: Could be optimized to only include used types + columnTypes := []string{ + "integer", "bigint", "smallint", + "serial", "bigserial", "smallserial", + "text", "varchar", "char", + "boolean", "numeric", "real", "doublePrecision", + "timestamp", "date", "time", "interval", + "json", "jsonb", "uuid", "bytea", + } + imports = append(imports, columnTypes...) + + if needsIndex { + imports = append(imports, "index") + } + if needsUniqueIndex { + imports = append(imports, "uniqueIndex") + } + + importLine := "import { " + strings.Join(imports, ", ") + " } from 'drizzle-orm/pg-core';" + templateData.AddImport(importLine) + + // Add SQL import if needed + if needsSQL { + templateData.AddImport("import { sql } from 'drizzle-orm';") + } +} + +// buildEnumMap builds a map of enum type names for quick lookup +func (w *Writer) buildEnumMap(db *models.Database) map[string]bool { + enumMap := make(map[string]bool) + + for _, schema := range db.Schemas { + for _, enum := range schema.Enums { + enumMap[enum.Name] = true + // Also add lowercase version for case-insensitive lookup + enumMap[strings.ToLower(enum.Name)] = true + } + } + + return enumMap +} + +// tableUsesEnum checks if a table uses any enum types +func (w *Writer) tableUsesEnum(table *models.Table, enumMap map[string]bool) bool { + for _, col := range table.Columns { + if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] { + return true + } + } + return false +} + +// getTableEnumNames returns the list of enum variable names used by a table +func (w *Writer) getTableEnumNames(table *models.Table, schema *models.Schema, enumMap map[string]bool) []string { + enumNames := make([]string, 0) + seen := make(map[string]bool) + + for _, col := range table.Columns { + if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] { + // Find the enum in schema + for _, enum := range schema.Enums { + if strings.EqualFold(enum.Name, col.Type) { + varName := w.typeMapper.ToCamelCase(enum.Name) + if !seen[varName] { + enumNames = append(enumNames, varName) + seen[varName] = true + } + break + } + } + } + } + + return enumNames +} + +// getSortedColumnNames returns column names sorted by sequence or name +func (w *Writer) getSortedColumnNames(table *models.Table) []string { + // Convert map to slice + columns := make([]*models.Column, 0, len(table.Columns)) + for _, col := range table.Columns { + columns = append(columns, col) + } + + // Sort by sequence, then by primary key, then by name + // (Similar to GORM writer) + sortColumns := func(i, j int) bool { + // Sort by sequence if both have it + if columns[i].Sequence > 0 && columns[j].Sequence > 0 { + return columns[i].Sequence < columns[j].Sequence + } + + // Put primary keys first + if columns[i].IsPrimaryKey != columns[j].IsPrimaryKey { + return columns[i].IsPrimaryKey + } + + // Otherwise sort alphabetically + return columns[i].Name < columns[j].Name + } + + // Create a custom sorter + for i := 0; i < len(columns); i++ { + for j := i + 1; j < len(columns); j++ { + if !sortColumns(i, j) { + columns[i], columns[j] = columns[j], columns[i] + } + } + } + + // Extract names + names := make([]string, len(columns)) + for i, col := range columns { + names[i] = col.Name + } + + return names +} + +// getForeignKeyForColumn returns the foreign key constraint for a column, if any +func (w *Writer) getForeignKeyForColumn(columnName string, table *models.Table) *models.Constraint { + for _, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint { + for _, col := range constraint.Columns { + if col == columnName { + return constraint + } + } + } + } + return nil +} + +// writeOutput writes the content to file or stdout +func (w *Writer) writeOutput(content string) error { + if w.options.OutputPath != "" { + return os.WriteFile(w.options.OutputPath, []byte(content), 0644) + } + + // Print to stdout + fmt.Print(content) + return nil +} + +// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path +func (w *Writer) shouldUseMultiFile() bool { + // Check if multi_file is explicitly set in metadata + if w.options.Metadata != nil { + if mf, ok := w.options.Metadata["multi_file"].(bool); ok { + return mf + } + } + + // Auto-detect based on output path + if w.options.OutputPath == "" { + // No output path means stdout (single file) + return false + } + + // Check if path ends with .ts (explicit file) + if strings.HasSuffix(w.options.OutputPath, ".ts") { + return false + } + + // Check if path ends with directory separator + if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") { + return true + } + + // Check if path exists and is a directory + info, err := os.Stat(w.options.OutputPath) + if err == nil && info.IsDir() { + return true + } + + // Default to single file for ambiguous cases + return false +} diff --git a/tests/assets/drizzle/schema-updated.ts b/tests/assets/drizzle/schema-updated.ts new file mode 100644 index 0000000..6f56da2 --- /dev/null +++ b/tests/assets/drizzle/schema-updated.ts @@ -0,0 +1,156 @@ +// Code generated by relspecgo. DO NOT EDIT. +import { pgTable, pgEnum, integer, bigint, smallint, serial, bigserial, smallserial, text, varchar, char, boolean, numeric, real, doublePrecision, timestamp, date, time, interval, json, jsonb, uuid, bytea } from 'drizzle-orm/pg-core'; +import { sql } from 'drizzle-orm'; + + +// Enums +export const userRole = pgEnum('UserRole', ['admin', 'user', 'moderator', 'guest']); +export const orderStatus = pgEnum('OrderStatus', ['pending', 'processing', 'shipped', 'delivered', 'cancelled']); + + +// Table: users +export const users = pgTable('users', { + id: serial('id').primaryKey(), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + email: varchar('email').notNull().unique(), + isActive: boolean('is_active').notNull().default(true), + lastLoginAt: timestamp('last_login_at'), + passwordHash: varchar('password_hash').notNull(), + profile: jsonb('profile'), + role: pgEnum('UserRole')('role').notNull(), + updatedAt: timestamp('updated_at').notNull().default(sql`now()`), + username: varchar('username').notNull().unique(), +}); + +// Types for users +export type Users = typeof users.$inferSelect; +export type NewUsers = typeof users.$inferInsert; +// Table: profiles +export const profiles = pgTable('profiles', { + id: serial('id').primaryKey(), + avatarUrl: varchar('avatar_url'), + bio: text('bio'), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + dateOfBirth: date('date_of_birth'), + firstName: varchar('first_name'), + lastName: varchar('last_name'), + phoneNumber: varchar('phone_number'), + updatedAt: timestamp('updated_at').notNull().default(sql`now()`), + userId: integer('user_id').notNull().unique().references(() => users.id), +}); + +// Types for profiles +export type Profiles = typeof profiles.$inferSelect; +export type NewProfiles = typeof profiles.$inferInsert; +// Table: posts +export const posts = pgTable('posts', { + id: serial('id').primaryKey(), + authorId: integer('author_id').notNull().references(() => users.id), + content: text('content').notNull(), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + excerpt: text('excerpt'), + featuredImage: varchar('featured_image'), + isPublished: boolean('is_published').notNull().default(false), + publishedAt: timestamp('published_at'), + slug: varchar('slug').notNull().unique(), + title: varchar('title').notNull(), + updatedAt: timestamp('updated_at').notNull().default(sql`now()`), + viewCount: integer('view_count').notNull().default(0), +}); + +// Types for posts +export type Posts = typeof posts.$inferSelect; +export type NewPosts = typeof posts.$inferInsert; +// Table: comments +export const comments = pgTable('comments', { + id: serial('id').primaryKey(), + authorId: integer('author_id').notNull().references(() => users.id), + content: text('content').notNull(), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + isApproved: boolean('is_approved').notNull().default(false), + parentId: integer('parent_id').references(() => comments.id), + postId: integer('post_id').notNull().references(() => posts.id), + updatedAt: timestamp('updated_at').notNull().default(sql`now()`), +}); + +// Types for comments +export type Comments = typeof comments.$inferSelect; +export type NewComments = typeof comments.$inferInsert; +// Table: categories +export const categories = pgTable('categories', { + id: serial('id').primaryKey(), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + description: text('description'), + name: varchar('name').notNull().unique(), + parentId: integer('parent_id').references(() => categories.id), + slug: varchar('slug').notNull().unique(), + updatedAt: timestamp('updated_at').notNull().default(sql`now()`), +}); + +// Types for categories +export type Categories = typeof categories.$inferSelect; +export type NewCategories = typeof categories.$inferInsert; +// Table: post_categories +export const postCategories = pgTable('post_categories', { + categoryId: integer('category_id').notNull().references(() => categories.id), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + postId: integer('post_id').notNull().references(() => posts.id), +}); + +// Types for post_categories +export type PostCategories = typeof postCategories.$inferSelect; +export type NewPostCategories = typeof postCategories.$inferInsert; +// Table: tags +export const tags = pgTable('tags', { + id: serial('id').primaryKey(), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + name: varchar('name').notNull().unique(), + slug: varchar('slug').notNull().unique(), +}); + +// Types for tags +export type Tags = typeof tags.$inferSelect; +export type NewTags = typeof tags.$inferInsert; +// Table: post_tags +export const postTags = pgTable('post_tags', { + createdAt: timestamp('created_at').notNull().default(sql`now()`), + postId: integer('post_id').notNull().references(() => posts.id), + tagId: integer('tag_id').notNull().references(() => tags.id), +}); + +// Types for post_tags +export type PostTags = typeof postTags.$inferSelect; +export type NewPostTags = typeof postTags.$inferInsert; +// Table: orders +export const orders = pgTable('orders', { + id: serial('id').primaryKey(), + billingAddress: jsonb('billing_address').notNull(), + completedAt: timestamp('completed_at'), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + currency: varchar('currency').notNull().default('USD'), + notes: text('notes'), + orderNumber: varchar('order_number').notNull().unique(), + shippingAddress: jsonb('shipping_address').notNull(), + status: pgEnum('OrderStatus')('status').notNull().default('pending'), + totalAmount: numeric('total_amount').notNull(), + updatedAt: timestamp('updated_at').notNull().default(sql`now()`), + userId: integer('user_id').notNull().references(() => users.id), +}); + +// Types for orders +export type Orders = typeof orders.$inferSelect; +export type NewOrders = typeof orders.$inferInsert; +// Table: sessions +export const sessions = pgTable('sessions', { + id: uuid('id').primaryKey().default(sql`gen_random_uuid()`), + createdAt: timestamp('created_at').notNull().default(sql`now()`), + expiresAt: timestamp('expires_at').notNull(), + ipAddress: varchar('ip_address'), + token: varchar('token').notNull().unique(), + userAgent: text('user_agent'), + userId: integer('user_id').notNull().references(() => users.id), +}); + +// Types for sessions +export type Sessions = typeof sessions.$inferSelect; +export type NewSessions = typeof sessions.$inferInsert; diff --git a/tests/assets/drizzle/schema.ts b/tests/assets/drizzle/schema.ts new file mode 100644 index 0000000..04a8c5e --- /dev/null +++ b/tests/assets/drizzle/schema.ts @@ -0,0 +1,90 @@ +// Code generated by relspecgo. DO NOT EDIT. +import { pgTable, pgEnum, integer, bigint, smallint, serial, bigserial, smallserial, text, varchar, char, boolean, numeric, real, doublePrecision, timestamp, date, time, interval, json, jsonb, uuid, bytea } from 'drizzle-orm/pg-core'; +import { sql } from 'drizzle-orm'; + + +// Enums +export const role = pgEnum('Role', ['USER', 'ADMIN']); +export type Role = 'USER' | 'ADMIN'; + + +// Table: User +export interface User { + id: number; + email: string; + name: string | null; + profile: string | null; + role: Role; +} + +export const user = pgTable('User', { + id: integer('id').primaryKey().generatedAlwaysAsIdentity(), + email: text('email').notNull().unique(), + name: text('name'), + profile: text('profile'), + role: pgEnum('Role')('role').notNull().default('USER'), +}); + +export type NewUser = typeof user.$inferInsert; +// Table: Profile +export interface Profile { + id: number; + bio: string; + user: string; + userId: number; +} + +export const profile = pgTable('Profile', { + id: integer('id').primaryKey().generatedAlwaysAsIdentity(), + bio: text('bio').notNull(), + user: text('user').notNull(), + userId: integer('userId').notNull().unique().references(() => user.id), +}); + +export type NewProfile = typeof profile.$inferInsert; +// Table: Post +export interface Post { + id: number; + author: string; + authorId: number; + createdAt: Date; + published: boolean; + title: string; + updatedAt: Date; // @updatedAt +} + +export const post = pgTable('Post', { + id: integer('id').primaryKey().generatedAlwaysAsIdentity(), + author: text('author').notNull(), + authorId: integer('authorId').notNull().references(() => user.id), + createdAt: timestamp('createdAt').notNull().default(sql`now()`), + published: boolean('published').notNull().default(false), + title: text('title').notNull(), + updatedAt: timestamp('updatedAt').notNull(), // @updatedAt +}); + +export type NewPost = typeof post.$inferInsert; +// Table: Category +export interface Category { + id: number; + name: string; +} + +export const category = pgTable('Category', { + id: integer('id').primaryKey().generatedAlwaysAsIdentity(), + name: text('name').notNull(), +}); + +export type NewCategory = typeof category.$inferInsert; +// Table: _CategoryToPost +export interface Categorytopost { + categoryId: number; + postId: number; +} + +export const Categorytopost = pgTable('_CategoryToPost', { + categoryId: integer('CategoryId').primaryKey().references(() => category.id), + postId: integer('PostId').primaryKey().references(() => post.id), +}); + +export type NewCategorytopost = typeof Categorytopost.$inferInsert;