Added Drizzle ORM support
This commit is contained in:
543
pkg/writers/drizzle/writer.go
Normal file
543
pkg/writers/drizzle/writer.go
Normal file
@@ -0,0 +1,543 @@
|
||||
package drizzle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// Writer implements the writers.Writer interface for Drizzle ORM
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
typeMapper *TypeMapper
|
||||
templates *Templates
|
||||
}
|
||||
|
||||
// NewWriter creates a new Drizzle writer with the given options
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
w := &Writer{
|
||||
options: options,
|
||||
typeMapper: NewTypeMapper(),
|
||||
}
|
||||
|
||||
// Initialize templates
|
||||
tmpl, err := NewTemplates()
|
||||
if err != nil {
|
||||
// Should not happen with embedded templates
|
||||
panic(fmt.Sprintf("failed to initialize templates: %v", err))
|
||||
}
|
||||
w.templates = tmpl
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// WriteDatabase writes a complete database as Drizzle schema
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
// Check if multi-file mode is enabled
|
||||
multiFile := w.shouldUseMultiFile()
|
||||
|
||||
if multiFile {
|
||||
return w.writeMultiFile(db)
|
||||
}
|
||||
|
||||
return w.writeSingleFile(db)
|
||||
}
|
||||
|
||||
// WriteSchema writes a schema as Drizzle schema
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
// Create a temporary database with just this schema
|
||||
db := models.InitDatabase(schema.Name)
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// WriteTable writes a single table as a Drizzle schema
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
// Create a temporary schema and database
|
||||
schema := models.InitSchema(table.Schema)
|
||||
schema.Tables = []*models.Table{table}
|
||||
|
||||
db := models.InitDatabase(schema.Name)
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// writeSingleFile writes all tables to a single file
|
||||
func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
templateData := NewTemplateData()
|
||||
|
||||
// Build enum map for quick lookup
|
||||
enumMap := w.buildEnumMap(db)
|
||||
|
||||
// Process all schemas
|
||||
for _, schema := range db.Schemas {
|
||||
// Add enums
|
||||
for _, enum := range schema.Enums {
|
||||
enumData := NewEnumData(enum, w.typeMapper)
|
||||
templateData.AddEnum(enumData)
|
||||
}
|
||||
|
||||
// Add tables
|
||||
for _, table := range schema.Tables {
|
||||
tableData := w.buildTableData(table, schema, db, enumMap)
|
||||
templateData.AddTable(tableData)
|
||||
}
|
||||
}
|
||||
|
||||
// Add imports
|
||||
w.addImports(templateData, db)
|
||||
|
||||
// Finalize imports
|
||||
templateData.FinalizeImports()
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(code)
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
// Ensure output path is a directory
|
||||
if w.options.OutputPath == "" {
|
||||
return fmt.Errorf("output path is required for multi-file mode")
|
||||
}
|
||||
|
||||
// Create output directory if it doesn't exist
|
||||
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
// Build enum map for quick lookup
|
||||
enumMap := w.buildEnumMap(db)
|
||||
|
||||
// Process all schemas
|
||||
for _, schema := range db.Schemas {
|
||||
// Write enums file if there are any
|
||||
if len(schema.Enums) > 0 {
|
||||
if err := w.writeEnumsFile(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write each table to a separate file
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeTableFile(table, schema, db, enumMap); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeEnumsFile writes all enums to a separate file
|
||||
func (w *Writer) writeEnumsFile(schema *models.Schema) error {
|
||||
templateData := NewTemplateData()
|
||||
|
||||
// Add enums
|
||||
for _, enum := range schema.Enums {
|
||||
enumData := NewEnumData(enum, w.typeMapper)
|
||||
templateData.AddEnum(enumData)
|
||||
}
|
||||
|
||||
// Add imports for enums
|
||||
templateData.AddImport("import { pgEnum } from 'drizzle-orm/pg-core';")
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate enums code: %w", err)
|
||||
}
|
||||
|
||||
// Write to enums.ts file
|
||||
filename := filepath.Join(w.options.OutputPath, "enums.ts")
|
||||
return os.WriteFile(filename, []byte(code), 0644)
|
||||
}
|
||||
|
||||
// writeTableFile writes a single table to its own file
|
||||
func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) error {
|
||||
templateData := NewTemplateData()
|
||||
|
||||
// Build table data
|
||||
tableData := w.buildTableData(table, schema, db, enumMap)
|
||||
templateData.AddTable(tableData)
|
||||
|
||||
// Add imports
|
||||
w.addImports(templateData, db)
|
||||
|
||||
// If there are enums, add import from enums file
|
||||
if len(schema.Enums) > 0 && w.tableUsesEnum(table, enumMap) {
|
||||
// Import enum definitions from enums.ts
|
||||
enumNames := w.getTableEnumNames(table, schema, enumMap)
|
||||
if len(enumNames) > 0 {
|
||||
importLine := fmt.Sprintf("import { %s } from './enums';", strings.Join(enumNames, ", "))
|
||||
templateData.AddImport(importLine)
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize imports
|
||||
templateData.FinalizeImports()
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
|
||||
}
|
||||
|
||||
// Generate filename: {tableName}.ts
|
||||
filename := filepath.Join(w.options.OutputPath, table.Name+".ts")
|
||||
return os.WriteFile(filename, []byte(code), 0644)
|
||||
}
|
||||
|
||||
// buildTableData builds TableData from a models.Table
|
||||
func (w *Writer) buildTableData(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) *TableData {
|
||||
tableData := NewTableData(table, w.typeMapper)
|
||||
|
||||
// Add columns
|
||||
for _, colName := range w.getSortedColumnNames(table) {
|
||||
col := table.Columns[colName]
|
||||
|
||||
// Check if this column uses an enum
|
||||
isEnum := enumMap[col.Type]
|
||||
|
||||
columnData := NewColumnData(col, table, w.typeMapper, isEnum)
|
||||
|
||||
// Set TypeScript type
|
||||
drizzleType := w.typeMapper.SQLTypeToDrizzle(col.Type)
|
||||
enumName := ""
|
||||
if isEnum {
|
||||
// For enums, use the enum type name
|
||||
enumName = col.Type
|
||||
}
|
||||
baseType := w.typeMapper.DrizzleTypeToTypeScript(drizzleType, isEnum, enumName)
|
||||
|
||||
// Add null union if column is nullable
|
||||
if !col.NotNull && !col.IsPrimaryKey {
|
||||
columnData.TypeScriptType = baseType + " | null"
|
||||
} else {
|
||||
columnData.TypeScriptType = baseType
|
||||
}
|
||||
|
||||
// Check if this column is a foreign key
|
||||
if fk := w.getForeignKeyForColumn(col.Name, table); fk != nil {
|
||||
columnData.IsForeignKey = true
|
||||
refTableName := fk.ReferencedTable
|
||||
refChain := w.typeMapper.BuildReferencesChain(fk, refTableName)
|
||||
if refChain != "" {
|
||||
columnData.ReferencesLine = "." + refChain
|
||||
// Append to the drizzle chain
|
||||
columnData.DrizzleChain += columnData.ReferencesLine
|
||||
}
|
||||
}
|
||||
|
||||
tableData.AddColumn(columnData)
|
||||
}
|
||||
|
||||
// Collect all column field names that are used in indexes
|
||||
indexColumnFields := make(map[string]bool)
|
||||
|
||||
// Add indexes (excluding single-column unique indexes, which are handled inline)
|
||||
for _, index := range table.Indexes {
|
||||
// Skip single-column unique indexes (handled by .unique() modifier)
|
||||
if index.Unique && len(index.Columns) == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Track which columns are used in indexes
|
||||
for _, colName := range index.Columns {
|
||||
// Find the field name for this column
|
||||
if col, exists := table.Columns[colName]; exists {
|
||||
fieldName := w.typeMapper.ToCamelCase(col.Name)
|
||||
indexColumnFields[fieldName] = true
|
||||
}
|
||||
}
|
||||
|
||||
indexData := NewIndexData(index, tableData.Name, w.typeMapper)
|
||||
tableData.AddIndex(indexData)
|
||||
}
|
||||
|
||||
// Add multi-column unique constraints as unique indexes
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
|
||||
// Create a unique index for this constraint
|
||||
indexData := &IndexData{
|
||||
Name: w.typeMapper.ToCamelCase(constraint.Name) + "Idx",
|
||||
Columns: constraint.Columns,
|
||||
IsUnique: true,
|
||||
}
|
||||
|
||||
// Track which columns are used in indexes
|
||||
for _, colName := range constraint.Columns {
|
||||
if col, exists := table.Columns[colName]; exists {
|
||||
fieldName := w.typeMapper.ToCamelCase(col.Name)
|
||||
indexColumnFields[fieldName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build column references as field names (for destructuring)
|
||||
colRefs := make([]string, len(constraint.Columns))
|
||||
for i, colName := range constraint.Columns {
|
||||
if col, exists := table.Columns[colName]; exists {
|
||||
colRefs[i] = w.typeMapper.ToCamelCase(col.Name)
|
||||
} else {
|
||||
colRefs[i] = w.typeMapper.ToCamelCase(colName)
|
||||
}
|
||||
}
|
||||
|
||||
indexData.Definition = "uniqueIndex('" + constraint.Name + "').on(" + joinStrings(colRefs, ", ") + ")"
|
||||
tableData.AddIndex(indexData)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert index column fields map to sorted slice
|
||||
if len(indexColumnFields) > 0 {
|
||||
fields := make([]string, 0, len(indexColumnFields))
|
||||
for field := range indexColumnFields {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
// Sort for consistent output
|
||||
sortStrings(fields)
|
||||
tableData.IndexColumnFields = fields
|
||||
}
|
||||
|
||||
return tableData
|
||||
}
|
||||
|
||||
// sortStrings sorts a slice of strings in place
|
||||
func sortStrings(strs []string) {
|
||||
for i := 0; i < len(strs); i++ {
|
||||
for j := i + 1; j < len(strs); j++ {
|
||||
if strs[i] > strs[j] {
|
||||
strs[i], strs[j] = strs[j], strs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addImports adds the necessary imports to the template data
|
||||
func (w *Writer) addImports(templateData *TemplateData, db *models.Database) {
|
||||
// Determine which Drizzle imports we need
|
||||
needsPgTable := len(templateData.Tables) > 0
|
||||
needsPgEnum := len(templateData.Enums) > 0
|
||||
needsIndex := false
|
||||
needsUniqueIndex := false
|
||||
needsSQL := false
|
||||
|
||||
// Check what we need based on tables
|
||||
for _, table := range templateData.Tables {
|
||||
for _, index := range table.Indexes {
|
||||
if index.IsUnique {
|
||||
needsUniqueIndex = true
|
||||
} else {
|
||||
needsIndex = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any column uses SQL default values
|
||||
for _, col := range table.Columns {
|
||||
if strings.Contains(col.DrizzleChain, "sql`") {
|
||||
needsSQL = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the import statement
|
||||
imports := make([]string, 0)
|
||||
|
||||
if needsPgTable {
|
||||
imports = append(imports, "pgTable")
|
||||
}
|
||||
if needsPgEnum {
|
||||
imports = append(imports, "pgEnum")
|
||||
}
|
||||
|
||||
// Add column types - for now, add common ones
|
||||
// TODO: Could be optimized to only include used types
|
||||
columnTypes := []string{
|
||||
"integer", "bigint", "smallint",
|
||||
"serial", "bigserial", "smallserial",
|
||||
"text", "varchar", "char",
|
||||
"boolean", "numeric", "real", "doublePrecision",
|
||||
"timestamp", "date", "time", "interval",
|
||||
"json", "jsonb", "uuid", "bytea",
|
||||
}
|
||||
imports = append(imports, columnTypes...)
|
||||
|
||||
if needsIndex {
|
||||
imports = append(imports, "index")
|
||||
}
|
||||
if needsUniqueIndex {
|
||||
imports = append(imports, "uniqueIndex")
|
||||
}
|
||||
|
||||
importLine := "import { " + strings.Join(imports, ", ") + " } from 'drizzle-orm/pg-core';"
|
||||
templateData.AddImport(importLine)
|
||||
|
||||
// Add SQL import if needed
|
||||
if needsSQL {
|
||||
templateData.AddImport("import { sql } from 'drizzle-orm';")
|
||||
}
|
||||
}
|
||||
|
||||
// buildEnumMap builds a map of enum type names for quick lookup
|
||||
func (w *Writer) buildEnumMap(db *models.Database) map[string]bool {
|
||||
enumMap := make(map[string]bool)
|
||||
|
||||
for _, schema := range db.Schemas {
|
||||
for _, enum := range schema.Enums {
|
||||
enumMap[enum.Name] = true
|
||||
// Also add lowercase version for case-insensitive lookup
|
||||
enumMap[strings.ToLower(enum.Name)] = true
|
||||
}
|
||||
}
|
||||
|
||||
return enumMap
|
||||
}
|
||||
|
||||
// tableUsesEnum checks if a table uses any enum types
|
||||
func (w *Writer) tableUsesEnum(table *models.Table, enumMap map[string]bool) bool {
|
||||
for _, col := range table.Columns {
|
||||
if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getTableEnumNames returns the list of enum variable names used by a table
|
||||
func (w *Writer) getTableEnumNames(table *models.Table, schema *models.Schema, enumMap map[string]bool) []string {
|
||||
enumNames := make([]string, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, col := range table.Columns {
|
||||
if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] {
|
||||
// Find the enum in schema
|
||||
for _, enum := range schema.Enums {
|
||||
if strings.EqualFold(enum.Name, col.Type) {
|
||||
varName := w.typeMapper.ToCamelCase(enum.Name)
|
||||
if !seen[varName] {
|
||||
enumNames = append(enumNames, varName)
|
||||
seen[varName] = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return enumNames
|
||||
}
|
||||
|
||||
// getSortedColumnNames returns column names sorted by sequence or name
|
||||
func (w *Writer) getSortedColumnNames(table *models.Table) []string {
|
||||
// Convert map to slice
|
||||
columns := make([]*models.Column, 0, len(table.Columns))
|
||||
for _, col := range table.Columns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
// Sort by sequence, then by primary key, then by name
|
||||
// (Similar to GORM writer)
|
||||
sortColumns := func(i, j int) bool {
|
||||
// Sort by sequence if both have it
|
||||
if columns[i].Sequence > 0 && columns[j].Sequence > 0 {
|
||||
return columns[i].Sequence < columns[j].Sequence
|
||||
}
|
||||
|
||||
// Put primary keys first
|
||||
if columns[i].IsPrimaryKey != columns[j].IsPrimaryKey {
|
||||
return columns[i].IsPrimaryKey
|
||||
}
|
||||
|
||||
// Otherwise sort alphabetically
|
||||
return columns[i].Name < columns[j].Name
|
||||
}
|
||||
|
||||
// Create a custom sorter
|
||||
for i := 0; i < len(columns); i++ {
|
||||
for j := i + 1; j < len(columns); j++ {
|
||||
if !sortColumns(i, j) {
|
||||
columns[i], columns[j] = columns[j], columns[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract names
|
||||
names := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
names[i] = col.Name
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// getForeignKeyForColumn returns the foreign key constraint for a column, if any
|
||||
func (w *Writer) getForeignKeyForColumn(columnName string, table *models.Table) *models.Constraint {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
for _, col := range constraint.Columns {
|
||||
if col == columnName {
|
||||
return constraint
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeOutput writes the content to file or stdout
|
||||
func (w *Writer) writeOutput(content string) error {
|
||||
if w.options.OutputPath != "" {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// Print to stdout
|
||||
fmt.Print(content)
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||
func (w *Writer) shouldUseMultiFile() bool {
|
||||
// Check if multi_file is explicitly set in metadata
|
||||
if w.options.Metadata != nil {
|
||||
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
|
||||
return mf
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect based on output path
|
||||
if w.options.OutputPath == "" {
|
||||
// No output path means stdout (single file)
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if path ends with .ts (explicit file)
|
||||
if strings.HasSuffix(w.options.OutputPath, ".ts") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if path ends with directory separator
|
||||
if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if path exists and is a directory
|
||||
info, err := os.Stat(w.options.OutputPath)
|
||||
if err == nil && info.IsDir() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Default to single file for ambiguous cases
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user