Files
relspecgo/pkg/writers/drizzle/writer.go
Hein f258f8baeb
All checks were successful
CI / Test (1.24) (push) Successful in -27m23s
CI / Test (1.25) (push) Successful in -27m16s
CI / Build (push) Successful in -27m40s
CI / Lint (push) Successful in -27m29s
Release / Build and Release (push) Successful in -27m21s
Integration Tests / Integration Tests (push) Successful in -27m17s
feat(writer): 🎉 Add filename sanitization for DBML identifiers
* Implement SanitizeFilename function to clean identifiers
* Remove quotes, comments, and invalid characters from filenames
* Update filename generation in writers to use sanitized names
2026-01-10 13:32:33 +02:00

546 lines
15 KiB
Go

package drizzle
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for Drizzle ORM
// It turns models.Database / models.Schema / models.Table definitions into
// TypeScript source, either as a single file or (multi-file mode) as one file
// per table plus a shared enums.ts. Construct it with NewWriter so that
// templates is initialized.
type Writer struct {
// options supplies OutputPath (file, directory, or "" for stdout) and
// Metadata (this file reads the "multi_file" bool key).
options *writers.WriterOptions
// typeMapper converts SQL types to Drizzle/TypeScript types and builds
// identifiers (e.g. ToCamelCase) and reference chains.
typeMapper *TypeMapper
// templates renders a TemplateData into TypeScript code via GenerateCode.
templates *Templates
}
// NewWriter returns a Drizzle Writer configured with options.
//
// The code templates are embedded in the binary, so failing to load them is a
// programming error rather than a runtime condition; NewWriter panics in that
// case instead of returning an error.
func NewWriter(options *writers.WriterOptions) *Writer {
	mapper := NewTypeMapper()

	tmpl, err := NewTemplates()
	if err != nil {
		// Should not happen with embedded templates.
		panic(fmt.Sprintf("failed to initialize templates: %v", err))
	}

	return &Writer{
		options:    options,
		typeMapper: mapper,
		templates:  tmpl,
	}
}
// WriteDatabase renders every schema of db as Drizzle code.
//
// Output goes either to a single file/stdout or, when shouldUseMultiFile
// reports true, to one file per table inside the output directory.
func (w *Writer) WriteDatabase(db *models.Database) error {
	if !w.shouldUseMultiFile() {
		return w.writeSingleFile(db)
	}
	return w.writeMultiFile(db)
}
// WriteSchema renders a single schema by wrapping it in a one-schema database
// (named after the schema) and handing that to WriteDatabase.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	wrapper := models.InitDatabase(schema.Name)
	wrapper.Schemas = []*models.Schema{schema}

	return w.WriteDatabase(wrapper)
}
// WriteTable renders a single table. The table is placed in a schema named
// after table.Schema, and that schema is written through WriteSchema, which
// wraps it in a database of the same name.
func (w *Writer) WriteTable(table *models.Table) error {
	holder := models.InitSchema(table.Schema)
	holder.Tables = []*models.Table{table}

	return w.WriteSchema(holder)
}
// writeSingleFile renders all enums and tables of every schema into one
// code unit and emits it via writeOutput (a file, or stdout when no output
// path is configured).
func (w *Writer) writeSingleFile(db *models.Database) error {
	data := NewTemplateData()
	enums := w.buildEnumMap(db)

	for _, s := range db.Schemas {
		for _, e := range s.Enums {
			data.AddEnum(NewEnumData(e, w.typeMapper))
		}
		for _, t := range s.Tables {
			data.AddTable(w.buildTableData(t, s, db, enums))
		}
	}

	w.addImports(data, db)
	data.FinalizeImports()

	code, err := w.templates.GenerateCode(data)
	if err != nil {
		return fmt.Errorf("failed to generate code: %w", err)
	}
	return w.writeOutput(code)
}
// writeMultiFile writes one <table>.ts file per table into OutputPath and,
// when any schema defines enums, a single shared enums.ts.
//
// Every per-table file imports its enums from './enums', so all schemas share
// one enums.ts. Enums are therefore collected across all schemas and written
// once. Writing the file per schema would let each later schema overwrite the
// enums of the earlier ones. This also matches single-file mode, which emits
// the enums of every schema together.
//
// NOTE(review): enums with the same name in different schemas, and tables with
// the same (sanitized) name in different schemas, still collide (duplicate
// declarations / the same file path); confirm whether that needs handling.
//
// Returns an error if OutputPath is empty, the directory cannot be created, or
// any file fails to generate/write.
func (w *Writer) writeMultiFile(db *models.Database) error {
	if w.options.OutputPath == "" {
		return fmt.Errorf("output path is required for multi-file mode")
	}

	if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}

	enumMap := w.buildEnumMap(db)

	// Gather the enums of every schema into one holder so enums.ts is written
	// exactly once with the complete set.
	allEnums := &models.Schema{}
	for _, schema := range db.Schemas {
		allEnums.Enums = append(allEnums.Enums, schema.Enums...)
	}
	if len(allEnums.Enums) > 0 {
		if err := w.writeEnumsFile(allEnums); err != nil {
			return err
		}
	}

	for _, schema := range db.Schemas {
		for _, table := range schema.Tables {
			if err := w.writeTableFile(table, schema, db, enumMap); err != nil {
				return err
			}
		}
	}

	return nil
}
// writeEnumsFile renders every enum of schema into <OutputPath>/enums.ts.
//
// In multi-file mode the per-table files import enum definitions from
// './enums', so this is the file they resolve against. Returns an error if
// code generation or the file write fails.
func (w *Writer) writeEnumsFile(schema *models.Schema) error {
	templateData := NewTemplateData()

	for _, enum := range schema.Enums {
		templateData.AddEnum(NewEnumData(enum, w.typeMapper))
	}

	templateData.AddImport("import { pgEnum } from 'drizzle-orm/pg-core';")
	// Every other generation path in this file (writeSingleFile,
	// writeTableFile) calls FinalizeImports after adding imports and before
	// GenerateCode. This path does the same so enums.ts goes through the same
	// import pipeline.
	// NOTE(review): confirm against TemplateData that FinalizeImports is what
	// makes added imports visible to the template.
	templateData.FinalizeImports()

	code, err := w.templates.GenerateCode(templateData)
	if err != nil {
		return fmt.Errorf("failed to generate enums code: %w", err)
	}

	filename := filepath.Join(w.options.OutputPath, "enums.ts")
	return os.WriteFile(filename, []byte(code), 0644)
}
// writeTableFile renders one table into <OutputPath>/<sanitized name>.ts.
//
// When the table's schema has enums and the table uses any of them, an
// `import { … } from './enums';` line is added so the file can reference the
// enum definitions kept in enums.ts.
func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) error {
	data := NewTemplateData()
	data.AddTable(w.buildTableData(table, schema, db, enumMap))
	w.addImports(data, db)

	if len(schema.Enums) > 0 && w.tableUsesEnum(table, enumMap) {
		if names := w.getTableEnumNames(table, schema, enumMap); len(names) > 0 {
			data.AddImport("import { " + strings.Join(names, ", ") + " } from './enums';")
		}
	}

	data.FinalizeImports()

	code, err := w.templates.GenerateCode(data)
	if err != nil {
		return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
	}

	// SanitizeFilename strips quotes, comments and invalid characters from the
	// table identifier before it is used as a file name.
	target := filepath.Join(w.options.OutputPath, writers.SanitizeFilename(table.Name)+".ts")
	return os.WriteFile(target, []byte(code), 0644)
}
// buildTableData converts a models.Table into the TableData consumed by the
// Drizzle templates.
//
// It fills in:
//   - columns, in getSortedColumnNames order, with TypeScript types
//     (nullable unless NOT NULL or primary key) and `.references(...)` chains
//     for foreign-key columns;
//   - indexes, skipping single-column unique indexes (those are emitted inline
//     as `.unique()`);
//   - multi-column unique constraints, rendered as `uniqueIndex(...)` entries;
//   - IndexColumnFields: the sorted camelCase field names referenced by any
//     emitted index, left unset when there are none.
//
// schema and db are accepted for interface symmetry but not read here.
func (w *Writer) buildTableData(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) *TableData {
	tm := w.typeMapper
	td := NewTableData(table, tm)

	for _, name := range w.getSortedColumnNames(table) {
		column := table.Columns[name]
		enumerated := enumMap[column.Type]
		cd := NewColumnData(column, table, tm, enumerated)

		// Enum columns use their SQL type name as the TypeScript enum name.
		var enumTypeName string
		if enumerated {
			enumTypeName = column.Type
		}
		tsType := tm.DrizzleTypeToTypeScript(tm.SQLTypeToDrizzle(column.Type), enumerated, enumTypeName)
		if nullable := !column.NotNull && !column.IsPrimaryKey; nullable {
			tsType += " | null"
		}
		cd.TypeScriptType = tsType

		if fk := w.getForeignKeyForColumn(column.Name, table); fk != nil {
			cd.IsForeignKey = true
			if chain := tm.BuildReferencesChain(fk, fk.ReferencedTable); chain != "" {
				cd.ReferencesLine = "." + chain
				cd.DrizzleChain += cd.ReferencesLine
			}
		}

		td.AddColumn(cd)
	}

	// usedFields collects the camelCase field names of every column that
	// appears in an emitted index; unknown column names are ignored.
	usedFields := make(map[string]bool)
	markFields := func(names []string) {
		for _, n := range names {
			if c, ok := table.Columns[n]; ok {
				usedFields[tm.ToCamelCase(c.Name)] = true
			}
		}
	}

	for _, idx := range table.Indexes {
		if idx.Unique && len(idx.Columns) == 1 {
			continue // rendered inline via .unique()
		}
		markFields(idx.Columns)
		td.AddIndex(NewIndexData(idx, td.Name, tm))
	}

	for _, c := range table.Constraints {
		if c.Type != models.UniqueConstraint || len(c.Columns) <= 1 {
			continue
		}

		idxName := tm.ToCamelCase(c.Name) + "Idx"
		markFields(c.Columns)

		// Field references for the .on(...) call; fall back to the raw
		// constraint column name when the column is not in the table.
		refs := make([]string, 0, len(c.Columns))
		for _, n := range c.Columns {
			if col, ok := table.Columns[n]; ok {
				refs = append(refs, tm.ToCamelCase(col.Name))
			} else {
				refs = append(refs, tm.ToCamelCase(n))
			}
		}

		td.AddIndex(&IndexData{
			Name:       idxName,
			Columns:    c.Columns,
			IsUnique:   true,
			Definition: "uniqueIndex('" + c.Name + "').on(" + joinStrings(refs, ", ") + ")",
		})
	}

	if len(usedFields) > 0 {
		fields := make([]string, 0, len(usedFields))
		for f := range usedFields {
			fields = append(fields, f)
		}
		sortStrings(fields) // map order is random; keep output stable
		td.IndexColumnFields = fields
	}

	return td
}
// sortStrings sorts strs in place in ascending lexicographic (byte-wise)
// order. It delegates to the standard library instead of the former
// hand-written O(n²) exchange sort; nil and empty slices are left untouched.
func sortStrings(strs []string) {
	sort.Strings(strs)
}
// addImports registers the Drizzle import statements needed by templateData.
//
// It always adds a single 'drizzle-orm/pg-core' import listing, in order:
// pgTable (if there are tables), pgEnum (if there are enums), a fixed set of
// common column builders, then index / uniqueIndex when any table declares a
// non-unique / unique index. A separate `sql` import from 'drizzle-orm' is
// added when some column's Drizzle chain contains an sql`…` template literal.
// db is currently not consulted.
func (w *Writer) addImports(templateData *TemplateData, db *models.Database) {
	var wantIndex, wantUniqueIndex, wantSQL bool
	for _, tbl := range templateData.Tables {
		for _, idx := range tbl.Indexes {
			wantUniqueIndex = wantUniqueIndex || idx.IsUnique
			wantIndex = wantIndex || !idx.IsUnique
		}
		for _, col := range tbl.Columns {
			wantSQL = wantSQL || strings.Contains(col.DrizzleChain, "sql`")
		}
	}

	names := []string{}
	if len(templateData.Tables) > 0 {
		names = append(names, "pgTable")
	}
	if len(templateData.Enums) > 0 {
		names = append(names, "pgEnum")
	}

	// TODO: Could be optimized to only include the column builders actually used.
	names = append(names,
		"integer", "bigint", "smallint",
		"serial", "bigserial", "smallserial",
		"text", "varchar", "char",
		"boolean", "numeric", "real", "doublePrecision",
		"timestamp", "date", "time", "interval",
		"json", "jsonb", "uuid", "bytea",
	)

	if wantIndex {
		names = append(names, "index")
	}
	if wantUniqueIndex {
		names = append(names, "uniqueIndex")
	}

	templateData.AddImport("import { " + strings.Join(names, ", ") + " } from 'drizzle-orm/pg-core';")

	if wantSQL {
		templateData.AddImport("import { sql } from 'drizzle-orm';")
	}
}
// buildEnumMap returns a set of the enum type names declared in any schema of
// db. Each name is stored both as declared and lower-cased, so callers can
// also probe with strings.ToLower(colType) for a case-insensitive hit.
func (w *Writer) buildEnumMap(db *models.Database) map[string]bool {
	lookup := map[string]bool{}
	for _, s := range db.Schemas {
		for _, e := range s.Enums {
			for _, key := range [...]string{e.Name, strings.ToLower(e.Name)} {
				lookup[key] = true
			}
		}
	}
	return lookup
}
// tableUsesEnum reports whether any column of table has a type found in
// enumMap, looked up both verbatim and lower-cased.
func (w *Writer) tableUsesEnum(table *models.Table, enumMap map[string]bool) bool {
	isEnumType := func(typ string) bool {
		return enumMap[typ] || enumMap[strings.ToLower(typ)]
	}

	for _, col := range table.Columns {
		if isEnumType(col.Type) {
			return true
		}
	}
	return false
}
// getTableEnumNames returns the camelCase variable names of the enums of
// schema that table's columns use, without duplicates and sorted ascending.
//
// A column counts as using an enum when its type is in enumMap (verbatim or
// lower-cased) and matches an enum of schema case-insensitively. The result is
// used to build the `import { … } from './enums'` line. table.Columns is a map
// with random iteration order, so the names are sorted to keep that generated
// line identical between runs. Always returns a non-nil slice.
func (w *Writer) getTableEnumNames(table *models.Table, schema *models.Schema, enumMap map[string]bool) []string {
	enumNames := make([]string, 0)
	seen := make(map[string]bool)

	for _, col := range table.Columns {
		if !enumMap[col.Type] && !enumMap[strings.ToLower(col.Type)] {
			continue
		}
		// Use the first enum in the schema whose name matches the column type.
		for _, enum := range schema.Enums {
			if !strings.EqualFold(enum.Name, col.Type) {
				continue
			}
			varName := w.typeMapper.ToCamelCase(enum.Name)
			if !seen[varName] {
				seen[varName] = true
				enumNames = append(enumNames, varName)
			}
			break
		}
	}

	sort.Strings(enumNames)
	return enumNames
}
// getSortedColumnNames returns the names of table's columns in output order.
//
// Ordering (similar to the GORM writer): when both columns have a positive
// Sequence they are ordered by Sequence; otherwise primary-key columns come
// first, then columns are ordered alphabetically by Name.
//
// table.Columns is a map, so its iteration order is random. The comparison
// above is also not transitive when only some columns carry a Sequence. For
// example, sequence 2 "a", sequence 1 "c" and unsequenced "b" form a cycle. The
// exchange sort's result therefore depends on its input order. Columns are
// pre-sorted by name so the generated code is reproducible between runs.
func (w *Writer) getSortedColumnNames(table *models.Table) []string {
	columns := make([]*models.Column, 0, len(table.Columns))
	for _, col := range table.Columns {
		columns = append(columns, col)
	}

	// Deterministic starting order, independent of map iteration.
	sort.Slice(columns, func(i, j int) bool {
		return columns[i].Name < columns[j].Name
	})

	less := func(a, b *models.Column) bool {
		// Sort by sequence if both have it.
		if a.Sequence > 0 && b.Sequence > 0 {
			return a.Sequence < b.Sequence
		}
		// Put primary keys first.
		if a.IsPrimaryKey != b.IsPrimaryKey {
			return a.IsPrimaryKey
		}
		// Otherwise sort alphabetically.
		return a.Name < b.Name
	}

	// less is not a strict weak ordering, so sort.Slice cannot be used for
	// this pass. The exchange sort is kept; it is deterministic for a fixed
	// input.
	for i := 0; i < len(columns); i++ {
		for j := i + 1; j < len(columns); j++ {
			if !less(columns[i], columns[j]) {
				columns[i], columns[j] = columns[j], columns[i]
			}
		}
	}

	names := make([]string, len(columns))
	for i, col := range columns {
		names[i] = col.Name
	}
	return names
}
// getForeignKeyForColumn returns the first foreign-key constraint of table
// whose column list contains columnName, or nil when there is none.
func (w *Writer) getForeignKeyForColumn(columnName string, table *models.Table) *models.Constraint {
	for _, c := range table.Constraints {
		if c.Type != models.ForeignKeyConstraint {
			continue
		}
		for _, name := range c.Columns {
			if name == columnName {
				return c
			}
		}
	}
	return nil
}
// writeOutput writes content to the configured output file, or prints it to
// stdout when no output path is set.
func (w *Writer) writeOutput(content string) error {
	if w.options.OutputPath == "" {
		fmt.Print(content)
		return nil
	}
	return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}
// shouldUseMultiFile decides between single-file and per-table output.
//
// An explicit bool under Metadata["multi_file"] always wins. Otherwise the
// output path decides: empty (stdout) or ending in ".ts" means single file; a
// trailing "/" or "\" or an existing directory means multi-file. Anything else
// falls back to single file.
func (w *Writer) shouldUseMultiFile() bool {
	// Indexing a nil map yields the zero value, so no nil guard is needed.
	if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
		return mf
	}

	out := w.options.OutputPath
	switch {
	case out == "":
		return false
	case strings.HasSuffix(out, ".ts"):
		return false
	case strings.HasSuffix(out, "/"), strings.HasSuffix(out, "\\"):
		return true
	}

	info, err := os.Stat(out)
	return err == nil && info.IsDir()
}