package drizzle import ( "fmt" "os" "path/filepath" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) // Writer implements the writers.Writer interface for Drizzle ORM type Writer struct { options *writers.WriterOptions typeMapper *TypeMapper templates *Templates } // NewWriter creates a new Drizzle writer with the given options func NewWriter(options *writers.WriterOptions) *Writer { w := &Writer{ options: options, typeMapper: NewTypeMapper(), } // Initialize templates tmpl, err := NewTemplates() if err != nil { // Should not happen with embedded templates panic(fmt.Sprintf("failed to initialize templates: %v", err)) } w.templates = tmpl return w } // WriteDatabase writes a complete database as Drizzle schema func (w *Writer) WriteDatabase(db *models.Database) error { // Check if multi-file mode is enabled multiFile := w.shouldUseMultiFile() if multiFile { return w.writeMultiFile(db) } return w.writeSingleFile(db) } // WriteSchema writes a schema as Drizzle schema func (w *Writer) WriteSchema(schema *models.Schema) error { // Create a temporary database with just this schema db := models.InitDatabase(schema.Name) db.Schemas = []*models.Schema{schema} return w.WriteDatabase(db) } // WriteTable writes a single table as a Drizzle schema func (w *Writer) WriteTable(table *models.Table) error { // Create a temporary schema and database schema := models.InitSchema(table.Schema) schema.Tables = []*models.Table{table} db := models.InitDatabase(schema.Name) db.Schemas = []*models.Schema{schema} return w.WriteDatabase(db) } // writeSingleFile writes all tables to a single file func (w *Writer) writeSingleFile(db *models.Database) error { templateData := NewTemplateData() // Build enum map for quick lookup enumMap := w.buildEnumMap(db) // Process all schemas for _, schema := range db.Schemas { // Add enums for _, enum := range schema.Enums { enumData := NewEnumData(enum, w.typeMapper) templateData.AddEnum(enumData) } // Add tables for _, table := range schema.Tables { tableData := w.buildTableData(table, schema, db, enumMap) templateData.AddTable(tableData) } } // Add imports w.addImports(templateData, db) // Finalize imports templateData.FinalizeImports() // Generate code code, err := w.templates.GenerateCode(templateData) if err != nil { return fmt.Errorf("failed to generate code: %w", err) } // Write output return w.writeOutput(code) } // writeMultiFile writes each table to a separate file func (w *Writer) writeMultiFile(db *models.Database) error { // Ensure output path is a directory if w.options.OutputPath == "" { return fmt.Errorf("output path is required for multi-file mode") } // Create output directory if it doesn't exist if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil { return fmt.Errorf("failed to create output directory: %w", err) } // Build enum map for quick lookup enumMap := w.buildEnumMap(db) // Process all schemas for _, schema := range db.Schemas { // Write enums file if there are any if len(schema.Enums) > 0 { if err := w.writeEnumsFile(schema); err != nil { return err } } // Write each table to a separate file for _, table := range schema.Tables { if err := w.writeTableFile(table, schema, db, enumMap); err != nil { return err } } } return nil } // writeEnumsFile writes all enums to a separate file func (w *Writer) writeEnumsFile(schema *models.Schema) error { templateData := NewTemplateData() // Add enums for _, enum := range schema.Enums { enumData := NewEnumData(enum, w.typeMapper) templateData.AddEnum(enumData) } // Add imports for enums templateData.AddImport("import { pgEnum } from 'drizzle-orm/pg-core';") // Generate code code, err := w.templates.GenerateCode(templateData) if err != nil { return fmt.Errorf("failed to generate enums code: %w", err) } // Write to enums.ts file filename := filepath.Join(w.options.OutputPath, "enums.ts") return os.WriteFile(filename, []byte(code), 0644) } // writeTableFile writes a single table to its own file func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) error { templateData := NewTemplateData() // Build table data tableData := w.buildTableData(table, schema, db, enumMap) templateData.AddTable(tableData) // Add imports w.addImports(templateData, db) // If there are enums, add import from enums file if len(schema.Enums) > 0 && w.tableUsesEnum(table, enumMap) { // Import enum definitions from enums.ts enumNames := w.getTableEnumNames(table, schema, enumMap) if len(enumNames) > 0 { importLine := fmt.Sprintf("import { %s } from './enums';", strings.Join(enumNames, ", ")) templateData.AddImport(importLine) } } // Finalize imports templateData.FinalizeImports() // Generate code code, err := w.templates.GenerateCode(templateData) if err != nil { return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err) } // Generate filename: {tableName}.ts // Sanitize table name to remove quotes, comments, and invalid characters safeTableName := writers.SanitizeFilename(table.Name) filename := filepath.Join(w.options.OutputPath, safeTableName+".ts") return os.WriteFile(filename, []byte(code), 0644) } // buildTableData builds TableData from a models.Table func (w *Writer) buildTableData(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) *TableData { tableData := NewTableData(table, w.typeMapper) // Add columns for _, colName := range w.getSortedColumnNames(table) { col := table.Columns[colName] // Check if this column uses an enum isEnum := enumMap[col.Type] columnData := NewColumnData(col, table, w.typeMapper, isEnum) // Set TypeScript type drizzleType := w.typeMapper.SQLTypeToDrizzle(col.Type) enumName := "" if isEnum { // For enums, use the enum type name enumName = col.Type } baseType := w.typeMapper.DrizzleTypeToTypeScript(drizzleType, isEnum, enumName) // Add null union if column is nullable if !col.NotNull && !col.IsPrimaryKey { columnData.TypeScriptType = baseType + " | null" } else { columnData.TypeScriptType = baseType } // Check if this column is a foreign key if fk := w.getForeignKeyForColumn(col.Name, table); fk != nil { columnData.IsForeignKey = true refTableName := fk.ReferencedTable refChain := w.typeMapper.BuildReferencesChain(fk, refTableName) if refChain != "" { columnData.ReferencesLine = "." + refChain // Append to the drizzle chain columnData.DrizzleChain += columnData.ReferencesLine } } tableData.AddColumn(columnData) } // Collect all column field names that are used in indexes indexColumnFields := make(map[string]bool) // Add indexes (excluding single-column unique indexes, which are handled inline) for _, index := range table.Indexes { // Skip single-column unique indexes (handled by .unique() modifier) if index.Unique && len(index.Columns) == 1 { continue } // Track which columns are used in indexes for _, colName := range index.Columns { // Find the field name for this column if col, exists := table.Columns[colName]; exists { fieldName := w.typeMapper.ToCamelCase(col.Name) indexColumnFields[fieldName] = true } } indexData := NewIndexData(index, tableData.Name, w.typeMapper) tableData.AddIndex(indexData) } // Add multi-column unique constraints as unique indexes for _, constraint := range table.Constraints { if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 { // Create a unique index for this constraint indexData := &IndexData{ Name: w.typeMapper.ToCamelCase(constraint.Name) + "Idx", Columns: constraint.Columns, IsUnique: true, } // Track which columns are used in indexes for _, colName := range constraint.Columns { if col, exists := table.Columns[colName]; exists { fieldName := w.typeMapper.ToCamelCase(col.Name) indexColumnFields[fieldName] = true } } // Build column references as field names (for destructuring) colRefs := make([]string, len(constraint.Columns)) for i, colName := range constraint.Columns { if col, exists := table.Columns[colName]; exists { colRefs[i] = w.typeMapper.ToCamelCase(col.Name) } else { colRefs[i] = w.typeMapper.ToCamelCase(colName) } } indexData.Definition = "uniqueIndex('" + constraint.Name + "').on(" + joinStrings(colRefs, ", ") + ")" tableData.AddIndex(indexData) } } // Convert index column fields map to sorted slice if len(indexColumnFields) > 0 { fields := make([]string, 0, len(indexColumnFields)) for field := range indexColumnFields { fields = append(fields, field) } // Sort for consistent output sortStrings(fields) tableData.IndexColumnFields = fields } return tableData } // sortStrings sorts a slice of strings in place func sortStrings(strs []string) { for i := 0; i < len(strs); i++ { for j := i + 1; j < len(strs); j++ { if strs[i] > strs[j] { strs[i], strs[j] = strs[j], strs[i] } } } } // addImports adds the necessary imports to the template data func (w *Writer) addImports(templateData *TemplateData, db *models.Database) { // Determine which Drizzle imports we need needsPgTable := len(templateData.Tables) > 0 needsPgEnum := len(templateData.Enums) > 0 needsIndex := false needsUniqueIndex := false needsSQL := false // Check what we need based on tables for _, table := range templateData.Tables { for _, index := range table.Indexes { if index.IsUnique { needsUniqueIndex = true } else { needsIndex = true } } // Check if any column uses SQL default values for _, col := range table.Columns { if strings.Contains(col.DrizzleChain, "sql`") { needsSQL = true } } } // Build the import statement imports := make([]string, 0) if needsPgTable { imports = append(imports, "pgTable") } if needsPgEnum { imports = append(imports, "pgEnum") } // Add column types - for now, add common ones // TODO: Could be optimized to only include used types columnTypes := []string{ "integer", "bigint", "smallint", "serial", "bigserial", "smallserial", "text", "varchar", "char", "boolean", "numeric", "real", "doublePrecision", "timestamp", "date", "time", "interval", "json", "jsonb", "uuid", "bytea", } imports = append(imports, columnTypes...) if needsIndex { imports = append(imports, "index") } if needsUniqueIndex { imports = append(imports, "uniqueIndex") } importLine := "import { " + strings.Join(imports, ", ") + " } from 'drizzle-orm/pg-core';" templateData.AddImport(importLine) // Add SQL import if needed if needsSQL { templateData.AddImport("import { sql } from 'drizzle-orm';") } } // buildEnumMap builds a map of enum type names for quick lookup func (w *Writer) buildEnumMap(db *models.Database) map[string]bool { enumMap := make(map[string]bool) for _, schema := range db.Schemas { for _, enum := range schema.Enums { enumMap[enum.Name] = true // Also add lowercase version for case-insensitive lookup enumMap[strings.ToLower(enum.Name)] = true } } return enumMap } // tableUsesEnum checks if a table uses any enum types func (w *Writer) tableUsesEnum(table *models.Table, enumMap map[string]bool) bool { for _, col := range table.Columns { if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] { return true } } return false } // getTableEnumNames returns the list of enum variable names used by a table func (w *Writer) getTableEnumNames(table *models.Table, schema *models.Schema, enumMap map[string]bool) []string { enumNames := make([]string, 0) seen := make(map[string]bool) for _, col := range table.Columns { if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] { // Find the enum in schema for _, enum := range schema.Enums { if strings.EqualFold(enum.Name, col.Type) { varName := w.typeMapper.ToCamelCase(enum.Name) if !seen[varName] { enumNames = append(enumNames, varName) seen[varName] = true } break } } } } return enumNames } // getSortedColumnNames returns column names sorted by sequence or name func (w *Writer) getSortedColumnNames(table *models.Table) []string { // Convert map to slice columns := make([]*models.Column, 0, len(table.Columns)) for _, col := range table.Columns { columns = append(columns, col) } // Sort by sequence, then by primary key, then by name // (Similar to GORM writer) sortColumns := func(i, j int) bool { // Sort by sequence if both have it if columns[i].Sequence > 0 && columns[j].Sequence > 0 { return columns[i].Sequence < columns[j].Sequence } // Put primary keys first if columns[i].IsPrimaryKey != columns[j].IsPrimaryKey { return columns[i].IsPrimaryKey } // Otherwise sort alphabetically return columns[i].Name < columns[j].Name } // Create a custom sorter for i := 0; i < len(columns); i++ { for j := i + 1; j < len(columns); j++ { if !sortColumns(i, j) { columns[i], columns[j] = columns[j], columns[i] } } } // Extract names names := make([]string, len(columns)) for i, col := range columns { names[i] = col.Name } return names } // getForeignKeyForColumn returns the foreign key constraint for a column, if any func (w *Writer) getForeignKeyForColumn(columnName string, table *models.Table) *models.Constraint { for _, constraint := range table.Constraints { if constraint.Type == models.ForeignKeyConstraint { for _, col := range constraint.Columns { if col == columnName { return constraint } } } } return nil } // writeOutput writes the content to file or stdout func (w *Writer) writeOutput(content string) error { if w.options.OutputPath != "" { return os.WriteFile(w.options.OutputPath, []byte(content), 0644) } // Print to stdout fmt.Print(content) return nil } // shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path func (w *Writer) shouldUseMultiFile() bool { // Check if multi_file is explicitly set in metadata if w.options.Metadata != nil { if mf, ok := w.options.Metadata["multi_file"].(bool); ok { return mf } } // Auto-detect based on output path if w.options.OutputPath == "" { // No output path means stdout (single file) return false } // Check if path ends with .ts (explicit file) if strings.HasSuffix(w.options.OutputPath, ".ts") { return false } // Check if path ends with directory separator if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") { return true } // Check if path exists and is a directory info, err := os.Stat(w.options.OutputPath) if err == nil && info.IsDir() { return true } // Default to single file for ambiguous cases return false }