Added Drizzle ORM support
Some checks failed
CI / Test (1.24) (push) Failing after -24m8s
CI / Test (1.25) (push) Failing after -23m54s
CI / Lint (push) Failing after -25m2s
CI / Build (push) Successful in -25m18s

This commit is contained in:
2025-12-28 10:15:30 +02:00
parent aad5db5175
commit 35bc9dfb5c
11 changed files with 2075 additions and 3 deletions

View File

@@ -0,0 +1,221 @@
package drizzle
import (
"sort"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// TemplateData represents the data passed to the template for code generation
type TemplateData struct {
Imports []string
Enums []*EnumData
Tables []*TableData
}
// EnumData represents an enum in the schema
type EnumData struct {
Name string // Enum name (PascalCase)
VarName string // Variable name for the enum (camelCase)
Values []string // Enum values
ValuesStr string // Comma-separated quoted values for pgEnum()
TypeUnion string // TypeScript union type (e.g., "'admin' | 'user' | 'guest'")
SchemaName string // Schema name
}
// TableData represents a table in the template
type TableData struct {
Name string // Table variable name (camelCase, e.g., users)
TableName string // Actual database table name (e.g., users)
TypeName string // TypeScript type name (PascalCase, e.g., Users)
Columns []*ColumnData // Column definitions
Indexes []*IndexData // Index definitions
Comment string // Table comment
SchemaName string // Schema name
NeedsSQLTag bool // Whether we need to import 'sql' from drizzle-orm
IndexColumnFields []string // Column field names used in indexes (for destructuring)
}
// ColumnData represents a column in a table
type ColumnData struct {
Name string // Column name in database
FieldName string // Field name in TypeScript (camelCase)
DrizzleChain string // Complete Drizzle column chain (e.g., "integer('id').primaryKey()")
TypeScriptType string // TypeScript type for interface (e.g., "string", "number | null")
IsForeignKey bool // Whether this is a foreign key
ReferencesLine string // The .references() line if FK
Comment string // Column comment
}
// IndexData represents an index definition
type IndexData struct {
Name string // Index name
Columns []string // Column names
IsUnique bool // Whether it's a unique index
Definition string // Complete index definition line
}
// NewTemplateData creates a new TemplateData
func NewTemplateData() *TemplateData {
return &TemplateData{
Imports: make([]string, 0),
Enums: make([]*EnumData, 0),
Tables: make([]*TableData, 0),
}
}
// AddImport adds an import to the template data (deduplicates automatically)
func (td *TemplateData) AddImport(importLine string) {
// Check if already exists
for _, imp := range td.Imports {
if imp == importLine {
return
}
}
td.Imports = append(td.Imports, importLine)
}
// AddEnum adds an enum to the template data
func (td *TemplateData) AddEnum(enum *EnumData) {
td.Enums = append(td.Enums, enum)
}
// AddTable adds a table to the template data
func (td *TemplateData) AddTable(table *TableData) {
td.Tables = append(td.Tables, table)
}
// FinalizeImports sorts imports
func (td *TemplateData) FinalizeImports() {
sort.Strings(td.Imports)
}
// NewEnumData creates EnumData from a models.Enum
func NewEnumData(enum *models.Enum, tm *TypeMapper) *EnumData {
// Keep enum name as-is (it should already be PascalCase from the source)
enumName := enum.Name
// Variable name is camelCase version
varName := tm.ToCamelCase(enum.Name)
// Format values as comma-separated quoted strings for pgEnum()
quotedValues := make([]string, len(enum.Values))
for i, v := range enum.Values {
quotedValues[i] = "'" + v + "'"
}
valuesStr := ""
for i, qv := range quotedValues {
if i > 0 {
valuesStr += ", "
}
valuesStr += qv
}
// Build TypeScript union type (e.g., "'admin' | 'user' | 'guest'")
typeUnion := ""
for i, qv := range quotedValues {
if i > 0 {
typeUnion += " | "
}
typeUnion += qv
}
return &EnumData{
Name: enumName,
VarName: varName,
Values: enum.Values,
ValuesStr: valuesStr,
TypeUnion: typeUnion,
SchemaName: enum.Schema,
}
}
// NewTableData creates TableData from a models.Table
func NewTableData(table *models.Table, tm *TypeMapper) *TableData {
tableName := tm.ToCamelCase(table.Name)
typeName := tm.ToPascalCase(table.Name)
return &TableData{
Name: tableName,
TableName: table.Name,
TypeName: typeName,
Columns: make([]*ColumnData, 0),
Indexes: make([]*IndexData, 0),
Comment: formatComment(table.Description, table.Comment),
SchemaName: table.Schema,
}
}
// AddColumn adds a column to the table data
func (td *TableData) AddColumn(col *ColumnData) {
td.Columns = append(td.Columns, col)
}
// AddIndex adds an index to the table data
func (td *TableData) AddIndex(idx *IndexData) {
td.Indexes = append(td.Indexes, idx)
}
// NewColumnData creates ColumnData from a models.Column
func NewColumnData(col *models.Column, table *models.Table, tm *TypeMapper, isEnum bool) *ColumnData {
fieldName := tm.ToCamelCase(col.Name)
drizzleChain := tm.BuildColumnChain(col, table, isEnum)
return &ColumnData{
Name: col.Name,
FieldName: fieldName,
DrizzleChain: drizzleChain,
Comment: formatComment(col.Description, col.Comment),
}
}
// NewIndexData creates IndexData from a models.Index
func NewIndexData(index *models.Index, tableVar string, tm *TypeMapper) *IndexData {
indexName := tm.ToCamelCase(index.Name) + "Idx"
// Build column references as field names (will be used with destructuring)
colRefs := make([]string, len(index.Columns))
for i, colName := range index.Columns {
// Use just the field name for destructured parameters
colRefs[i] = tm.ToCamelCase(colName)
}
// Build the complete definition
// Example: index('email_idx').on(email)
// or: uniqueIndex('unique_email_idx').on(email)
definition := ""
if index.Unique {
definition = "uniqueIndex('" + index.Name + "').on(" + joinStrings(colRefs, ", ") + ")"
} else {
definition = "index('" + index.Name + "').on(" + joinStrings(colRefs, ", ") + ")"
}
return &IndexData{
Name: indexName,
Columns: index.Columns,
IsUnique: index.Unique,
Definition: definition,
}
}
// formatComment combines description and comment into a single comment string
func formatComment(description, comment string) string {
if description != "" && comment != "" {
return description + " - " + comment
}
if description != "" {
return description
}
return comment
}
// joinStrings joins a slice of strings with a separator
func joinStrings(strs []string, sep string) string {
result := ""
for i, s := range strs {
if i > 0 {
result += sep
}
result += s
}
return result
}

View File

@@ -0,0 +1,64 @@
package drizzle
import (
"bytes"
"text/template"
)
// schemaTemplate defines the template for generating Drizzle schemas
const schemaTemplate = `// Code generated by relspecgo. DO NOT EDIT.
{{range .Imports}}{{.}}
{{end}}
{{if .Enums}}
// Enums
{{range .Enums}}export const {{.VarName}} = pgEnum('{{.Name}}', [{{.ValuesStr}}]);
export type {{.Name}} = {{.TypeUnion}};
{{end}}
{{end}}
{{range .Tables}}// Table: {{.TableName}}{{if .Comment}} - {{.Comment}}{{end}}
export interface {{.TypeName}} {
{{- range $i, $col := .Columns}}
{{$col.FieldName}}: {{$col.TypeScriptType}};{{if $col.Comment}} // {{$col.Comment}}{{end}}
{{- end}}
}
export const {{.Name}} = pgTable('{{.TableName}}', {
{{- range $i, $col := .Columns}}
{{$col.FieldName}}: {{$col.DrizzleChain}},{{if $col.Comment}} // {{$col.Comment}}{{end}}
{{- end}}
}{{if .Indexes}}{{if .IndexColumnFields}}, ({ {{range $i, $field := .IndexColumnFields}}{{if $i}}, {{end}}{{$field}}{{end}} }) => [{{else}}, (table) => [{{end}}
{{- range $i, $idx := .Indexes}}
{{$idx.Definition}},
{{- end}}
]{{end}});
export type New{{.TypeName}} = typeof {{.Name}}.$inferInsert;
{{end}}`
// Templates holds the parsed templates
type Templates struct {
schemaTmpl *template.Template
}
// NewTemplates creates and parses the templates
func NewTemplates() (*Templates, error) {
schemaTmpl, err := template.New("schema").Parse(schemaTemplate)
if err != nil {
return nil, err
}
return &Templates{
schemaTmpl: schemaTmpl,
}, nil
}
// GenerateCode executes the template with the given data
func (t *Templates) GenerateCode(data *TemplateData) (string, error) {
var buf bytes.Buffer
err := t.schemaTmpl.Execute(&buf, data)
if err != nil {
return "", err
}
return buf.String(), nil
}

View File

@@ -0,0 +1,318 @@
package drizzle
import (
"fmt"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// TypeMapper handles SQL to Drizzle type conversions
type TypeMapper struct{}
// NewTypeMapper creates a new TypeMapper instance
func NewTypeMapper() *TypeMapper {
return &TypeMapper{}
}
// SQLTypeToDrizzle converts SQL types to Drizzle column type functions
// Returns the Drizzle column constructor (e.g., "integer", "varchar", "text")
func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string {
sqlTypeLower := strings.ToLower(sqlType)
// PostgreSQL type mapping to Drizzle
typeMap := map[string]string{
// Integer types
"integer": "integer",
"int": "integer",
"int4": "integer",
"smallint": "smallint",
"int2": "smallint",
"bigint": "bigint",
"int8": "bigint",
// Serial types
"serial": "serial",
"serial4": "serial",
"smallserial": "smallserial",
"serial2": "smallserial",
"bigserial": "bigserial",
"serial8": "bigserial",
// Numeric types
"numeric": "numeric",
"decimal": "numeric",
"real": "real",
"float4": "real",
"double precision": "doublePrecision",
"float": "doublePrecision",
"float8": "doublePrecision",
// Character types
"text": "text",
"varchar": "varchar",
"character varying": "varchar",
"char": "char",
"character": "char",
// Boolean
"boolean": "boolean",
"bool": "boolean",
// Binary
"bytea": "bytea",
// JSON types
"json": "json",
"jsonb": "jsonb",
// Date/Time types
"time": "time",
"timetz": "time",
"timestamp": "timestamp",
"timestamptz": "timestamp",
"date": "date",
"interval": "interval",
// UUID
"uuid": "uuid",
// Geometric types
"point": "point",
"line": "line",
}
// Check for exact match first
if drizzleType, ok := typeMap[sqlTypeLower]; ok {
return drizzleType
}
// Check for partial matches (e.g., "varchar(255)" -> "varchar")
for sqlPattern, drizzleType := range typeMap {
if strings.HasPrefix(sqlTypeLower, sqlPattern) {
return drizzleType
}
}
// Default to text for unknown types
return "text"
}
// BuildColumnChain builds the complete column definition chain for Drizzle
// Example: integer('id').primaryKey().notNull()
func (tm *TypeMapper) BuildColumnChain(col *models.Column, table *models.Table, isEnum bool) string {
var parts []string
// Determine Drizzle column type
var drizzleType string
if isEnum {
// For enum types, use the type name directly
drizzleType = fmt.Sprintf("pgEnum('%s')", col.Type)
} else {
drizzleType = tm.SQLTypeToDrizzle(col.Type)
}
// Start with column type and name
// Note: column name is passed as first argument to the column constructor
base := fmt.Sprintf("%s('%s')", drizzleType, col.Name)
parts = append(parts, base)
// Add column modifiers in order
modifiers := tm.buildColumnModifiers(col, table)
if len(modifiers) > 0 {
parts = append(parts, modifiers...)
}
return strings.Join(parts, ".")
}
// buildColumnModifiers builds an array of method calls for column modifiers
func (tm *TypeMapper) buildColumnModifiers(col *models.Column, table *models.Table) []string {
var modifiers []string
// Primary key
if col.IsPrimaryKey {
modifiers = append(modifiers, "primaryKey()")
}
// Not null constraint
if col.NotNull && !col.IsPrimaryKey {
modifiers = append(modifiers, "notNull()")
}
// Unique constraint (check if there's a single-column unique constraint)
if tm.hasUniqueConstraint(col.Name, table) {
modifiers = append(modifiers, "unique()")
}
// Default value
if col.AutoIncrement {
// For auto-increment, use generatedAlwaysAsIdentity()
modifiers = append(modifiers, "generatedAlwaysAsIdentity()")
} else if col.Default != nil {
defaultValue := tm.formatDefaultValue(col.Default)
if defaultValue != "" {
modifiers = append(modifiers, fmt.Sprintf("default(%s)", defaultValue))
}
}
return modifiers
}
// formatDefaultValue formats a default value for Drizzle
func (tm *TypeMapper) formatDefaultValue(defaultValue any) string {
switch v := defaultValue.(type) {
case string:
if v == "now()" || v == "CURRENT_TIMESTAMP" {
return "sql`now()`"
} else if v == "gen_random_uuid()" || strings.Contains(strings.ToLower(v), "uuid") {
return "sql`gen_random_uuid()`"
} else {
// Try to parse as number first
// Check if it's a numeric string that should be a number
if isNumericString(v) {
return v
}
// String literal
return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "\\'"))
}
case bool:
if v {
return "true"
}
return "false"
case int, int64, int32, int16, int8:
return fmt.Sprintf("%v", v)
case float32, float64:
return fmt.Sprintf("%v", v)
default:
return fmt.Sprintf("%v", v)
}
}
// isNumericString checks if a string represents a number
func isNumericString(s string) bool {
if s == "" {
return false
}
// Simple check for numeric strings
for i, c := range s {
if i == 0 && c == '-' {
continue // Allow negative sign at start
}
if c < '0' || c > '9' {
if c != '.' {
return false
}
}
}
return true
}
// hasUniqueConstraint checks if a column has a unique constraint
func (tm *TypeMapper) hasUniqueConstraint(colName string, table *models.Table) bool {
for _, constraint := range table.Constraints {
if constraint.Type == models.UniqueConstraint &&
len(constraint.Columns) == 1 &&
constraint.Columns[0] == colName {
return true
}
}
return false
}
// BuildReferencesChain builds the .references() chain for foreign key columns
func (tm *TypeMapper) BuildReferencesChain(fk *models.Constraint, referencedTable string) string {
// Example: .references(() => users.id)
if len(fk.ReferencedColumns) > 0 {
// Use the referenced table variable name (camelCase)
refTableVar := tm.ToCamelCase(referencedTable)
refColumn := fk.ReferencedColumns[0]
return fmt.Sprintf("references(() => %s.%s)", refTableVar, refColumn)
}
return ""
}
// ToCamelCase converts snake_case or PascalCase to camelCase
func (tm *TypeMapper) ToCamelCase(s string) string {
if s == "" {
return s
}
// Check if it's snake_case
if strings.Contains(s, "_") {
parts := strings.Split(s, "_")
if len(parts) == 0 {
return s
}
// First part stays lowercase
result := strings.ToLower(parts[0])
// Capitalize first letter of remaining parts
for i := 1; i < len(parts); i++ {
if len(parts[i]) > 0 {
result += strings.ToUpper(parts[i][:1]) + strings.ToLower(parts[i][1:])
}
}
return result
}
// Otherwise, assume it's PascalCase - just lowercase the first letter
return strings.ToLower(s[:1]) + s[1:]
}
// ToPascalCase converts snake_case to PascalCase
func (tm *TypeMapper) ToPascalCase(s string) string {
parts := strings.Split(s, "_")
var result string
for _, part := range parts {
if len(part) > 0 {
result += strings.ToUpper(part[:1]) + strings.ToLower(part[1:])
}
}
return result
}
// DrizzleTypeToTypeScript converts Drizzle column types to TypeScript types
func (tm *TypeMapper) DrizzleTypeToTypeScript(drizzleType string, isEnum bool, enumName string) string {
if isEnum {
return enumName
}
typeMap := map[string]string{
"integer": "number",
"bigint": "number",
"smallint": "number",
"serial": "number",
"bigserial": "number",
"smallserial": "number",
"numeric": "number",
"real": "number",
"doublePrecision": "number",
"text": "string",
"varchar": "string",
"char": "string",
"boolean": "boolean",
"bytea": "Buffer",
"json": "any",
"jsonb": "any",
"timestamp": "Date",
"date": "Date",
"time": "Date",
"interval": "string",
"uuid": "string",
"point": "{ x: number; y: number }",
"line": "{ a: number; b: number; c: number }",
}
if tsType, ok := typeMap[drizzleType]; ok {
return tsType
}
// Default to any for unknown types
return "any"
}

View File

@@ -0,0 +1,543 @@
package drizzle
import (
"fmt"
"os"
"path/filepath"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for Drizzle ORM
type Writer struct {
options *writers.WriterOptions
typeMapper *TypeMapper
templates *Templates
}
// NewWriter creates a new Drizzle writer with the given options
func NewWriter(options *writers.WriterOptions) *Writer {
w := &Writer{
options: options,
typeMapper: NewTypeMapper(),
}
// Initialize templates
tmpl, err := NewTemplates()
if err != nil {
// Should not happen with embedded templates
panic(fmt.Sprintf("failed to initialize templates: %v", err))
}
w.templates = tmpl
return w
}
// WriteDatabase writes a complete database as Drizzle schema
func (w *Writer) WriteDatabase(db *models.Database) error {
// Check if multi-file mode is enabled
multiFile := w.shouldUseMultiFile()
if multiFile {
return w.writeMultiFile(db)
}
return w.writeSingleFile(db)
}
// WriteSchema writes a schema as Drizzle schema
func (w *Writer) WriteSchema(schema *models.Schema) error {
// Create a temporary database with just this schema
db := models.InitDatabase(schema.Name)
db.Schemas = []*models.Schema{schema}
return w.WriteDatabase(db)
}
// WriteTable writes a single table as a Drizzle schema
func (w *Writer) WriteTable(table *models.Table) error {
// Create a temporary schema and database
schema := models.InitSchema(table.Schema)
schema.Tables = []*models.Table{table}
db := models.InitDatabase(schema.Name)
db.Schemas = []*models.Schema{schema}
return w.WriteDatabase(db)
}
// writeSingleFile writes all tables to a single file
func (w *Writer) writeSingleFile(db *models.Database) error {
templateData := NewTemplateData()
// Build enum map for quick lookup
enumMap := w.buildEnumMap(db)
// Process all schemas
for _, schema := range db.Schemas {
// Add enums
for _, enum := range schema.Enums {
enumData := NewEnumData(enum, w.typeMapper)
templateData.AddEnum(enumData)
}
// Add tables
for _, table := range schema.Tables {
tableData := w.buildTableData(table, schema, db, enumMap)
templateData.AddTable(tableData)
}
}
// Add imports
w.addImports(templateData, db)
// Finalize imports
templateData.FinalizeImports()
// Generate code
code, err := w.templates.GenerateCode(templateData)
if err != nil {
return fmt.Errorf("failed to generate code: %w", err)
}
// Write output
return w.writeOutput(code)
}
// writeMultiFile writes each table to a separate file
func (w *Writer) writeMultiFile(db *models.Database) error {
// Ensure output path is a directory
if w.options.OutputPath == "" {
return fmt.Errorf("output path is required for multi-file mode")
}
// Create output directory if it doesn't exist
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}
// Build enum map for quick lookup
enumMap := w.buildEnumMap(db)
// Process all schemas
for _, schema := range db.Schemas {
// Write enums file if there are any
if len(schema.Enums) > 0 {
if err := w.writeEnumsFile(schema); err != nil {
return err
}
}
// Write each table to a separate file
for _, table := range schema.Tables {
if err := w.writeTableFile(table, schema, db, enumMap); err != nil {
return err
}
}
}
return nil
}
// writeEnumsFile writes all enums to a separate file
func (w *Writer) writeEnumsFile(schema *models.Schema) error {
templateData := NewTemplateData()
// Add enums
for _, enum := range schema.Enums {
enumData := NewEnumData(enum, w.typeMapper)
templateData.AddEnum(enumData)
}
// Add imports for enums
templateData.AddImport("import { pgEnum } from 'drizzle-orm/pg-core';")
// Generate code
code, err := w.templates.GenerateCode(templateData)
if err != nil {
return fmt.Errorf("failed to generate enums code: %w", err)
}
// Write to enums.ts file
filename := filepath.Join(w.options.OutputPath, "enums.ts")
return os.WriteFile(filename, []byte(code), 0644)
}
// writeTableFile writes a single table to its own file
func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) error {
templateData := NewTemplateData()
// Build table data
tableData := w.buildTableData(table, schema, db, enumMap)
templateData.AddTable(tableData)
// Add imports
w.addImports(templateData, db)
// If there are enums, add import from enums file
if len(schema.Enums) > 0 && w.tableUsesEnum(table, enumMap) {
// Import enum definitions from enums.ts
enumNames := w.getTableEnumNames(table, schema, enumMap)
if len(enumNames) > 0 {
importLine := fmt.Sprintf("import { %s } from './enums';", strings.Join(enumNames, ", "))
templateData.AddImport(importLine)
}
}
// Finalize imports
templateData.FinalizeImports()
// Generate code
code, err := w.templates.GenerateCode(templateData)
if err != nil {
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
}
// Generate filename: {tableName}.ts
filename := filepath.Join(w.options.OutputPath, table.Name+".ts")
return os.WriteFile(filename, []byte(code), 0644)
}
// buildTableData builds TableData from a models.Table
func (w *Writer) buildTableData(table *models.Table, schema *models.Schema, db *models.Database, enumMap map[string]bool) *TableData {
tableData := NewTableData(table, w.typeMapper)
// Add columns
for _, colName := range w.getSortedColumnNames(table) {
col := table.Columns[colName]
// Check if this column uses an enum
isEnum := enumMap[col.Type]
columnData := NewColumnData(col, table, w.typeMapper, isEnum)
// Set TypeScript type
drizzleType := w.typeMapper.SQLTypeToDrizzle(col.Type)
enumName := ""
if isEnum {
// For enums, use the enum type name
enumName = col.Type
}
baseType := w.typeMapper.DrizzleTypeToTypeScript(drizzleType, isEnum, enumName)
// Add null union if column is nullable
if !col.NotNull && !col.IsPrimaryKey {
columnData.TypeScriptType = baseType + " | null"
} else {
columnData.TypeScriptType = baseType
}
// Check if this column is a foreign key
if fk := w.getForeignKeyForColumn(col.Name, table); fk != nil {
columnData.IsForeignKey = true
refTableName := fk.ReferencedTable
refChain := w.typeMapper.BuildReferencesChain(fk, refTableName)
if refChain != "" {
columnData.ReferencesLine = "." + refChain
// Append to the drizzle chain
columnData.DrizzleChain += columnData.ReferencesLine
}
}
tableData.AddColumn(columnData)
}
// Collect all column field names that are used in indexes
indexColumnFields := make(map[string]bool)
// Add indexes (excluding single-column unique indexes, which are handled inline)
for _, index := range table.Indexes {
// Skip single-column unique indexes (handled by .unique() modifier)
if index.Unique && len(index.Columns) == 1 {
continue
}
// Track which columns are used in indexes
for _, colName := range index.Columns {
// Find the field name for this column
if col, exists := table.Columns[colName]; exists {
fieldName := w.typeMapper.ToCamelCase(col.Name)
indexColumnFields[fieldName] = true
}
}
indexData := NewIndexData(index, tableData.Name, w.typeMapper)
tableData.AddIndex(indexData)
}
// Add multi-column unique constraints as unique indexes
for _, constraint := range table.Constraints {
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
// Create a unique index for this constraint
indexData := &IndexData{
Name: w.typeMapper.ToCamelCase(constraint.Name) + "Idx",
Columns: constraint.Columns,
IsUnique: true,
}
// Track which columns are used in indexes
for _, colName := range constraint.Columns {
if col, exists := table.Columns[colName]; exists {
fieldName := w.typeMapper.ToCamelCase(col.Name)
indexColumnFields[fieldName] = true
}
}
// Build column references as field names (for destructuring)
colRefs := make([]string, len(constraint.Columns))
for i, colName := range constraint.Columns {
if col, exists := table.Columns[colName]; exists {
colRefs[i] = w.typeMapper.ToCamelCase(col.Name)
} else {
colRefs[i] = w.typeMapper.ToCamelCase(colName)
}
}
indexData.Definition = "uniqueIndex('" + constraint.Name + "').on(" + joinStrings(colRefs, ", ") + ")"
tableData.AddIndex(indexData)
}
}
// Convert index column fields map to sorted slice
if len(indexColumnFields) > 0 {
fields := make([]string, 0, len(indexColumnFields))
for field := range indexColumnFields {
fields = append(fields, field)
}
// Sort for consistent output
sortStrings(fields)
tableData.IndexColumnFields = fields
}
return tableData
}
// sortStrings sorts a slice of strings in place
func sortStrings(strs []string) {
for i := 0; i < len(strs); i++ {
for j := i + 1; j < len(strs); j++ {
if strs[i] > strs[j] {
strs[i], strs[j] = strs[j], strs[i]
}
}
}
}
// addImports adds the necessary imports to the template data
func (w *Writer) addImports(templateData *TemplateData, db *models.Database) {
// Determine which Drizzle imports we need
needsPgTable := len(templateData.Tables) > 0
needsPgEnum := len(templateData.Enums) > 0
needsIndex := false
needsUniqueIndex := false
needsSQL := false
// Check what we need based on tables
for _, table := range templateData.Tables {
for _, index := range table.Indexes {
if index.IsUnique {
needsUniqueIndex = true
} else {
needsIndex = true
}
}
// Check if any column uses SQL default values
for _, col := range table.Columns {
if strings.Contains(col.DrizzleChain, "sql`") {
needsSQL = true
}
}
}
// Build the import statement
imports := make([]string, 0)
if needsPgTable {
imports = append(imports, "pgTable")
}
if needsPgEnum {
imports = append(imports, "pgEnum")
}
// Add column types - for now, add common ones
// TODO: Could be optimized to only include used types
columnTypes := []string{
"integer", "bigint", "smallint",
"serial", "bigserial", "smallserial",
"text", "varchar", "char",
"boolean", "numeric", "real", "doublePrecision",
"timestamp", "date", "time", "interval",
"json", "jsonb", "uuid", "bytea",
}
imports = append(imports, columnTypes...)
if needsIndex {
imports = append(imports, "index")
}
if needsUniqueIndex {
imports = append(imports, "uniqueIndex")
}
importLine := "import { " + strings.Join(imports, ", ") + " } from 'drizzle-orm/pg-core';"
templateData.AddImport(importLine)
// Add SQL import if needed
if needsSQL {
templateData.AddImport("import { sql } from 'drizzle-orm';")
}
}
// buildEnumMap builds a map of enum type names for quick lookup
func (w *Writer) buildEnumMap(db *models.Database) map[string]bool {
enumMap := make(map[string]bool)
for _, schema := range db.Schemas {
for _, enum := range schema.Enums {
enumMap[enum.Name] = true
// Also add lowercase version for case-insensitive lookup
enumMap[strings.ToLower(enum.Name)] = true
}
}
return enumMap
}
// tableUsesEnum checks if a table uses any enum types
func (w *Writer) tableUsesEnum(table *models.Table, enumMap map[string]bool) bool {
for _, col := range table.Columns {
if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] {
return true
}
}
return false
}
// getTableEnumNames returns the list of enum variable names used by a table
func (w *Writer) getTableEnumNames(table *models.Table, schema *models.Schema, enumMap map[string]bool) []string {
enumNames := make([]string, 0)
seen := make(map[string]bool)
for _, col := range table.Columns {
if enumMap[col.Type] || enumMap[strings.ToLower(col.Type)] {
// Find the enum in schema
for _, enum := range schema.Enums {
if strings.EqualFold(enum.Name, col.Type) {
varName := w.typeMapper.ToCamelCase(enum.Name)
if !seen[varName] {
enumNames = append(enumNames, varName)
seen[varName] = true
}
break
}
}
}
}
return enumNames
}
// getSortedColumnNames returns column names sorted by sequence or name
func (w *Writer) getSortedColumnNames(table *models.Table) []string {
// Convert map to slice
columns := make([]*models.Column, 0, len(table.Columns))
for _, col := range table.Columns {
columns = append(columns, col)
}
// Sort by sequence, then by primary key, then by name
// (Similar to GORM writer)
sortColumns := func(i, j int) bool {
// Sort by sequence if both have it
if columns[i].Sequence > 0 && columns[j].Sequence > 0 {
return columns[i].Sequence < columns[j].Sequence
}
// Put primary keys first
if columns[i].IsPrimaryKey != columns[j].IsPrimaryKey {
return columns[i].IsPrimaryKey
}
// Otherwise sort alphabetically
return columns[i].Name < columns[j].Name
}
// Create a custom sorter
for i := 0; i < len(columns); i++ {
for j := i + 1; j < len(columns); j++ {
if !sortColumns(i, j) {
columns[i], columns[j] = columns[j], columns[i]
}
}
}
// Extract names
names := make([]string, len(columns))
for i, col := range columns {
names[i] = col.Name
}
return names
}
// getForeignKeyForColumn returns the foreign key constraint for a column, if any
func (w *Writer) getForeignKeyForColumn(columnName string, table *models.Table) *models.Constraint {
for _, constraint := range table.Constraints {
if constraint.Type == models.ForeignKeyConstraint {
for _, col := range constraint.Columns {
if col == columnName {
return constraint
}
}
}
}
return nil
}
// writeOutput writes the content to file or stdout
func (w *Writer) writeOutput(content string) error {
if w.options.OutputPath != "" {
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}
// Print to stdout
fmt.Print(content)
return nil
}
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
func (w *Writer) shouldUseMultiFile() bool {
// Check if multi_file is explicitly set in metadata
if w.options.Metadata != nil {
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
return mf
}
}
// Auto-detect based on output path
if w.options.OutputPath == "" {
// No output path means stdout (single file)
return false
}
// Check if path ends with .ts (explicit file)
if strings.HasSuffix(w.options.OutputPath, ".ts") {
return false
}
// Check if path ends with directory separator
if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") {
return true
}
// Check if path exists and is a directory
info, err := os.Stat(w.options.OutputPath)
if err == nil && info.IsDir() {
return true
}
// Default to single file for ambiguous cases
return false
}