544 lines
15 KiB
Go
544 lines
15 KiB
Go
package drizzle
|
|
|
|
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
|
|
|
|
// Writer implements the writers.Writer interface for Drizzle ORM.
// It renders database models as TypeScript schema source, either into a
// single file (or stdout when no output path is set) or into one file per
// table.
type Writer struct {
	// options holds caller-supplied settings; OutputPath and Metadata are read
	// by this writer.
	options *writers.WriterOptions

	// typeMapper converts SQL types and identifiers into Drizzle/TypeScript forms.
	typeMapper *TypeMapper

	// templates renders TemplateData into the final TypeScript source.
	templates *Templates
}
|
|
|
|
// NewWriter creates a new Drizzle writer with the given options
|
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
|
w := &Writer{
|
|
options: options,
|
|
typeMapper: NewTypeMapper(),
|
|
}
|
|
|
|
// Initialize templates
|
|
tmpl, err := NewTemplates()
|
|
if err != nil {
|
|
// Should not happen with embedded templates
|
|
panic(fmt.Sprintf("failed to initialize templates: %v", err))
|
|
}
|
|
w.templates = tmpl
|
|
|
|
return w
|
|
}
|
|
|
|
// WriteDatabase writes a complete database as Drizzle schema
|
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
|
// Check if multi-file mode is enabled
|
|
multiFile := w.shouldUseMultiFile()
|
|
|
|
if multiFile {
|
|
return w.writeMultiFile(db)
|
|
}
|
|
|
|
return w.writeSingleFile(db)
|
|
}
|
|
|
|
// WriteSchema writes a schema as Drizzle schema
|
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|
// Create a temporary database with just this schema
|
|
db := models.InitDatabase(schema.Name)
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// WriteTable writes a single table as a Drizzle schema
|
|
func (w *Writer) WriteTable(table *models.Table) error {
|
|
// Create a temporary schema and database
|
|
schema := models.InitSchema(table.Schema)
|
|
schema.Tables = []*models.Table{table}
|
|
|
|
db := models.InitDatabase(schema.Name)
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// writeSingleFile writes all tables to a single file
|
|
func (w *Writer) writeSingleFile(db *models.Database) error {
|
|
templateData := NewTemplateData()
|
|
|
|
// Build enum map for quick lookup
|
|
enumMap := w.buildEnumMap(db)
|
|
|
|
// Process all schemas
|
|
for _, schema := range db.Schemas {
|
|
// Add enums
|
|
for _, enum := range schema.Enums {
|
|
enumData := NewEnumData(enum, w.typeMapper)
|
|
templateData.AddEnum(enumData)
|
|
}
|
|
|
|
// Add tables
|
|
for _, table := range schema.Tables {
|
|
tableData := w.buildTableData(table, schema, db, enumMap)
|
|
templateData.AddTable(tableData)
|
|
}
|
|
}
|
|
|
|
// Add imports
|
|
w.addImports(templateData, db)
|
|
|
|
// Finalize imports
|
|
templateData.FinalizeImports()
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate code: %w", err)
|
|
}
|
|
|
|
// Write output
|
|
return w.writeOutput(code)
|
|
}
|
|
|
|
// writeMultiFile writes each table to a separate file
|
|
func (w *Writer) writeMultiFile(db *models.Database) error {
|
|
// Ensure output path is a directory
|
|
if w.options.OutputPath == "" {
|
|
return fmt.Errorf("output path is required for multi-file mode")
|
|
}
|
|
|
|
// Create output directory if it doesn't exist
|
|
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
|
|
return fmt.Errorf("failed to create output directory: %w", err)
|
|
}
|
|
|
|
// Build enum map for quick lookup
|
|
enumMap := w.buildEnumMap(db)
|
|
|
|
// Process all schemas
|
|
for _, schema := range db.Schemas {
|
|
// Write enums file if there are any
|
|
if len(schema.Enums) > 0 {
|
|
if err := w.writeEnumsFile(schema); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Write each table to a separate file
|
|
for _, table := range schema.Tables {
|
|
if err := w.writeTableFile(table, schema, db, enumMap); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeEnumsFile writes all enums to a separate file
|
|
func (w *Writer) writeEnumsFile(schema *models.Schema) error {
|
|
templateData := NewTemplateData()
|
|
|
|
// Add enums
|
|
for _, enum := range schema.Enums {
|
|
enumData := NewEnumData(enum, w.typeMapper)
|
|
templateData.AddEnum(enumData)
|
|
}
|
|
|
|
// Add imports for enums
|
|
templateData.AddImport("import { pgEnum } from 'drizzle-orm/pg-core';")
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate enums code: %w", err)
|
|
}
|
|
|
|
// Write to enums.ts file
|
|
filename := filepath.Join(w.options.OutputPath, "enums.ts")
|
|
return os.WriteFile(filename, []byte(code), 0644)
|
|
}
|
|
|
|
// writeTableFile writes a single table to its own file
|
|
func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) error {
|
|
templateData := NewTemplateData()
|
|
|
|
// Build table data
|
|
tableData := w.buildTableData(table, schema, db, enumMap)
|
|
templateData.AddTable(tableData)
|
|
|
|
// Add imports
|
|
w.addImports(templateData, db)
|
|
|
|
// If there are enums, add import from enums file
|
|
if len(schema.Enums) > 0 && w.tableUsesEnum(table, enumMap) {
|
|
// Import enum definitions from enums.ts
|
|
enumNames := w.getTableEnumNames(table, schema, enumMap)
|
|
if len(enumNames) > 0 {
|
|
importLine := fmt.Sprintf("import { %s } from './enums';", strings.Join(enumNames, ", "))
|
|
templateData.AddImport(importLine)
|
|
}
|
|
}
|
|
|
|
// Finalize imports
|
|
templateData.FinalizeImports()
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
|
|
}
|
|
|
|
// Generate filename: {tableName}.ts
|
|
filename := filepath.Join(w.options.OutputPath, table.Name+".ts")
|
|
return os.WriteFile(filename, []byte(code), 0644)
|
|
}
|
|
|
|
// buildTableData builds TableData from a models.Table
|
|
func (w *Writer) buildTableData(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) *TableData {
|
|
tableData := NewTableData(table, w.typeMapper)
|
|
|
|
// Add columns
|
|
for _, colName := range w.getSortedColumnNames(table) {
|
|
col := table.Columns[colName]
|
|
|
|
// Check if this column uses an enum
|
|
isEnum := enumMap[col.Type]
|
|
|
|
columnData := NewColumnData(col, table, w.typeMapper, isEnum)
|
|
|
|
// Set TypeScript type
|
|
drizzleType := w.typeMapper.SQLTypeToDrizzle(col.Type)
|
|
enumName := ""
|
|
if isEnum {
|
|
// For enums, use the enum type name
|
|
enumName = col.Type
|
|
}
|
|
baseType := w.typeMapper.DrizzleTypeToTypeScript(drizzleType, isEnum, enumName)
|
|
|
|
// Add null union if column is nullable
|
|
if !col.NotNull && !col.IsPrimaryKey {
|
|
columnData.TypeScriptType = baseType + " | null"
|
|
} else {
|
|
columnData.TypeScriptType = baseType
|
|
}
|
|
|
|
// Check if this column is a foreign key
|
|
if fk := w.getForeignKeyForColumn(col.Name, table); fk != nil {
|
|
columnData.IsForeignKey = true
|
|
refTableName := fk.ReferencedTable
|
|
refChain := w.typeMapper.BuildReferencesChain(fk, refTableName)
|
|
if refChain != "" {
|
|
columnData.ReferencesLine = "." + refChain
|
|
// Append to the drizzle chain
|
|
columnData.DrizzleChain += columnData.ReferencesLine
|
|
}
|
|
}
|
|
|
|
tableData.AddColumn(columnData)
|
|
}
|
|
|
|
// Collect all column field names that are used in indexes
|
|
indexColumnFields := make(map[string]bool)
|
|
|
|
// Add indexes (excluding single-column unique indexes, which are handled inline)
|
|
for _, index := range table.Indexes {
|
|
// Skip single-column unique indexes (handled by .unique() modifier)
|
|
if index.Unique && len(index.Columns) == 1 {
|
|
continue
|
|
}
|
|
|
|
// Track which columns are used in indexes
|
|
for _, colName := range index.Columns {
|
|
// Find the field name for this column
|
|
if col, exists := table.Columns[colName]; exists {
|
|
fieldName := w.typeMapper.ToCamelCase(col.Name)
|
|
indexColumnFields[fieldName] = true
|
|
}
|
|
}
|
|
|
|
indexData := NewIndexData(index, tableData.Name, w.typeMapper)
|
|
tableData.AddIndex(indexData)
|
|
}
|
|
|
|
// Add multi-column unique constraints as unique indexes
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
|
|
// Create a unique index for this constraint
|
|
indexData := &IndexData{
|
|
Name: w.typeMapper.ToCamelCase(constraint.Name) + "Idx",
|
|
Columns: constraint.Columns,
|
|
IsUnique: true,
|
|
}
|
|
|
|
// Track which columns are used in indexes
|
|
for _, colName := range constraint.Columns {
|
|
if col, exists := table.Columns[colName]; exists {
|
|
fieldName := w.typeMapper.ToCamelCase(col.Name)
|
|
indexColumnFields[fieldName] = true
|
|
}
|
|
}
|
|
|
|
// Build column references as field names (for destructuring)
|
|
colRefs := make([]string, len(constraint.Columns))
|
|
for i, colName := range constraint.Columns {
|
|
if col, exists := table.Columns[colName]; exists {
|
|
colRefs[i] = w.typeMapper.ToCamelCase(col.Name)
|
|
} else {
|
|
colRefs[i] = w.typeMapper.ToCamelCase(colName)
|
|
}
|
|
}
|
|
|
|
indexData.Definition = "uniqueIndex('" + constraint.Name + "').on(" + joinStrings(colRefs, ", ") + ")"
|
|
tableData.AddIndex(indexData)
|
|
}
|
|
}
|
|
|
|
// Convert index column fields map to sorted slice
|
|
if len(indexColumnFields) > 0 {
|
|
fields := make([]string, 0, len(indexColumnFields))
|
|
for field := range indexColumnFields {
|
|
fields = append(fields, field)
|
|
}
|
|
// Sort for consistent output
|
|
sortStrings(fields)
|
|
tableData.IndexColumnFields = fields
|
|
}
|
|
|
|
return tableData
|
|
}
|
|
|
|
// sortStrings sorts strs in place in ascending byte-wise (lexical) order.
// It delegates to sort.Strings (O(n log n)) instead of a hand-written
// quadratic sort. A nil or empty slice is left untouched.
func sortStrings(strs []string) {
	sort.Strings(strs)
}
|
|
|
|
// addImports adds the necessary imports to the template data
|
|
func (w *Writer) addImports(templateData *TemplateData, db *models.Database) {
|
|
// Determine which Drizzle imports we need
|
|
needsPgTable := len(templateData.Tables) > 0
|
|
needsPgEnum := len(templateData.Enums) > 0
|
|
needsIndex := false
|
|
needsUniqueIndex := false
|
|
needsSQL := false
|
|
|
|
// Check what we need based on tables
|
|
for _, table := range templateData.Tables {
|
|
for _, index := range table.Indexes {
|
|
if index.IsUnique {
|
|
needsUniqueIndex = true
|
|
} else {
|
|
needsIndex = true
|
|
}
|
|
}
|
|
|
|
// Check if any column uses SQL default values
|
|
for _, col := range table.Columns {
|
|
if strings.Contains(col.DrizzleChain, "sql`") {
|
|
needsSQL = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build the import statement
|
|
imports := make([]string, 0)
|
|
|
|
if needsPgTable {
|
|
imports = append(imports, "pgTable")
|
|
}
|
|
if needsPgEnum {
|
|
imports = append(imports, "pgEnum")
|
|
}
|
|
|
|
// Add column types - for now, add common ones
|
|
// TODO: Could be optimized to only include used types
|
|
columnTypes := []string{
|
|
"integer", "bigint", "smallint",
|
|
"serial", "bigserial", "smallserial",
|
|
"text", "varchar", "char",
|
|
"boolean", "numeric", "real", "doublePrecision",
|
|
"timestamp", "date", "time", "interval",
|
|
"json", "jsonb", "uuid", "bytea",
|
|
}
|
|
imports = append(imports, columnTypes...)
|
|
|
|
if needsIndex {
|
|
imports = append(imports, "index")
|
|
}
|
|
if needsUniqueIndex {
|
|
imports = append(imports, "uniqueIndex")
|
|
}
|
|
|
|
importLine := "import { " + strings.Join(imports, ", ") + " } from 'drizzle-orm/pg-core';"
|
|
templateData.AddImport(importLine)
|
|
|
|
// Add SQL import if needed
|
|
if needsSQL {
|
|
templateData.AddImport("import { sql } from 'drizzle-orm';")
|
|
}
|
|
}
|
|
|
|
// buildEnumMap builds a map of enum type names for quick lookup
|
|
func (w *Writer) buildEnumMap(db *models.Database) map[string]bool {
|
|
enumMap := make(map[string]bool)
|
|
|
|
for _, schema := range db.Schemas {
|
|
for _, enum := range schema.Enums {
|
|
enumMap[enum.Name] = true
|
|
// Also add lowercase version for case-insensitive lookup
|
|
enumMap[strings.ToLower(enum.Name)] = true
|
|
}
|
|
}
|
|
|
|
return enumMap
|
|
}
|
|
|
|
// tableUsesEnum checks if a table uses any enum types
|
|
func (w *Writer) tableUsesEnum(table *models.Table, enumMap map[string]bool) bool {
|
|
for _, col := range table.Columns {
|
|
if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// getTableEnumNames returns the list of enum variable names used by a table
|
|
func (w *Writer) getTableEnumNames(table *models.Table, schema *models.Schema, enumMap map[string]bool) []string {
|
|
enumNames := make([]string, 0)
|
|
seen := make(map[string]bool)
|
|
|
|
for _, col := range table.Columns {
|
|
if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] {
|
|
// Find the enum in schema
|
|
for _, enum := range schema.Enums {
|
|
if strings.EqualFold(enum.Name, col.Type) {
|
|
varName := w.typeMapper.ToCamelCase(enum.Name)
|
|
if !seen[varName] {
|
|
enumNames = append(enumNames, varName)
|
|
seen[varName] = true
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return enumNames
|
|
}
|
|
|
|
// getSortedColumnNames returns column names sorted by sequence or name
|
|
func (w *Writer) getSortedColumnNames(table *models.Table) []string {
|
|
// Convert map to slice
|
|
columns := make([]*models.Column, 0, len(table.Columns))
|
|
for _, col := range table.Columns {
|
|
columns = append(columns, col)
|
|
}
|
|
|
|
// Sort by sequence, then by primary key, then by name
|
|
// (Similar to GORM writer)
|
|
sortColumns := func(i, j int) bool {
|
|
// Sort by sequence if both have it
|
|
if columns[i].Sequence > 0 && columns[j].Sequence > 0 {
|
|
return columns[i].Sequence < columns[j].Sequence
|
|
}
|
|
|
|
// Put primary keys first
|
|
if columns[i].IsPrimaryKey != columns[j].IsPrimaryKey {
|
|
return columns[i].IsPrimaryKey
|
|
}
|
|
|
|
// Otherwise sort alphabetically
|
|
return columns[i].Name < columns[j].Name
|
|
}
|
|
|
|
// Create a custom sorter
|
|
for i := 0; i < len(columns); i++ {
|
|
for j := i + 1; j < len(columns); j++ {
|
|
if !sortColumns(i, j) {
|
|
columns[i], columns[j] = columns[j], columns[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
// Extract names
|
|
names := make([]string, len(columns))
|
|
for i, col := range columns {
|
|
names[i] = col.Name
|
|
}
|
|
|
|
return names
|
|
}
|
|
|
|
// getForeignKeyForColumn returns the foreign key constraint for a column, if any
|
|
func (w *Writer) getForeignKeyForColumn(columnName string, table *models.Table) *models.Constraint {
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.ForeignKeyConstraint {
|
|
for _, col := range constraint.Columns {
|
|
if col == columnName {
|
|
return constraint
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// writeOutput writes the content to file or stdout
|
|
func (w *Writer) writeOutput(content string) error {
|
|
if w.options.OutputPath != "" {
|
|
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
|
}
|
|
|
|
// Print to stdout
|
|
fmt.Print(content)
|
|
return nil
|
|
}
|
|
|
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
|
func (w *Writer) shouldUseMultiFile() bool {
|
|
// Check if multi_file is explicitly set in metadata
|
|
if w.options.Metadata != nil {
|
|
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
|
|
return mf
|
|
}
|
|
}
|
|
|
|
// Auto-detect based on output path
|
|
if w.options.OutputPath == "" {
|
|
// No output path means stdout (single file)
|
|
return false
|
|
}
|
|
|
|
// Check if path ends with .ts (explicit file)
|
|
if strings.HasSuffix(w.options.OutputPath, ".ts") {
|
|
return false
|
|
}
|
|
|
|
// Check if path ends with directory separator
|
|
if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") {
|
|
return true
|
|
}
|
|
|
|
// Check if path exists and is a directory
|
|
info, err := os.Stat(w.options.OutputPath)
|
|
if err == nil && info.IsDir() {
|
|
return true
|
|
}
|
|
|
|
// Default to single file for ambiguous cases
|
|
return false
|
|
}
|