Implemented TypeORM, Prisma and Enums on a schema
This commit is contained in:
551
pkg/writers/prisma/writer.go
Normal file
551
pkg/writers/prisma/writer.go
Normal file
@@ -0,0 +1,551 @@
|
||||
package prisma
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// Writer implements the writers.Writer interface for Prisma schema format
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
}
|
||||
|
||||
// NewWriter creates a new Prisma writer with the given options
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
return &Writer{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDatabase writes a Database model to Prisma schema format
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
content := w.databaseToPrisma(db)
|
||||
|
||||
if w.options.OutputPath != "" {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
}
|
||||
|
||||
fmt.Print(content)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteSchema writes a Schema model to Prisma schema format
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
// Create temporary database for schema
|
||||
db := models.InitDatabase("database")
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// WriteTable writes a Table model to Prisma schema format
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
// Create temporary schema and database for table
|
||||
schema := models.InitSchema(table.Schema)
|
||||
schema.Tables = []*models.Table{table}
|
||||
|
||||
return w.WriteSchema(schema)
|
||||
}
|
||||
|
||||
// databaseToPrisma converts a Database to Prisma schema format string
|
||||
func (w *Writer) databaseToPrisma(db *models.Database) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Write datasource block
|
||||
sb.WriteString(w.generateDatasource(db))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Write generator block
|
||||
sb.WriteString(w.generateGenerator())
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Process all schemas (typically just one in Prisma)
|
||||
for _, schema := range db.Schemas {
|
||||
// Write enums
|
||||
if len(schema.Enums) > 0 {
|
||||
for _, enum := range schema.Enums {
|
||||
sb.WriteString(w.enumToPrisma(enum))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Identify join tables for implicit M2M
|
||||
joinTables := w.identifyJoinTables(schema)
|
||||
|
||||
// Write models (excluding join tables)
|
||||
for _, table := range schema.Tables {
|
||||
if joinTables[table.Name] {
|
||||
continue // Skip join tables
|
||||
}
|
||||
sb.WriteString(w.tableToPrisma(table, schema, joinTables))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// generateDatasource generates the datasource block
|
||||
func (w *Writer) generateDatasource(db *models.Database) string {
|
||||
provider := "postgresql"
|
||||
|
||||
// Map database type to Prisma provider
|
||||
switch db.DatabaseType {
|
||||
case models.PostgresqlDatabaseType:
|
||||
provider = "postgresql"
|
||||
case models.MSSQLDatabaseType:
|
||||
provider = "sqlserver"
|
||||
case models.SqlLiteDatabaseType:
|
||||
provider = "sqlite"
|
||||
case "mysql":
|
||||
provider = "mysql"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`datasource db {
|
||||
provider = "%s"
|
||||
url = env("DATABASE_URL")
|
||||
}
|
||||
`, provider)
|
||||
}
|
||||
|
||||
// generateGenerator generates the generator block
|
||||
func (w *Writer) generateGenerator() string {
|
||||
return `generator client {
|
||||
provider = "prisma-client-js"
|
||||
}
|
||||
`
|
||||
}
|
||||
|
||||
// enumToPrisma converts an Enum to Prisma enum block
|
||||
func (w *Writer) enumToPrisma(enum *models.Enum) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("enum %s {\n", enum.Name))
|
||||
for _, value := range enum.Values {
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", value))
|
||||
}
|
||||
sb.WriteString("}\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// identifyJoinTables identifies tables that are join tables for M2M relations
|
||||
func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
|
||||
joinTables := make(map[string]bool)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Check if this is a join table:
|
||||
// 1. Starts with _ (Prisma convention)
|
||||
// 2. Has exactly 2 FK constraints
|
||||
// 3. Has composite PK with those 2 columns
|
||||
// 4. Has no other columns except the FK columns
|
||||
|
||||
if !strings.HasPrefix(table.Name, "_") {
|
||||
continue
|
||||
}
|
||||
|
||||
fks := table.GetForeignKeys()
|
||||
if len(fks) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if columns are only the FK columns
|
||||
if len(table.Columns) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if both FK columns are part of PK
|
||||
pkCols := 0
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
pkCols++
|
||||
}
|
||||
}
|
||||
|
||||
if pkCols == 2 {
|
||||
joinTables[table.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
return joinTables
|
||||
}
|
||||
|
||||
// tableToPrisma converts a Table to Prisma model block
|
||||
func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("model %s {\n", table.Name))
|
||||
|
||||
// Collect columns to write
|
||||
columns := make([]*models.Column, 0, len(table.Columns))
|
||||
for _, col := range table.Columns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
// Sort columns for consistent output
|
||||
sort.Slice(columns, func(i, j int) bool {
|
||||
return columns[i].Name < columns[j].Name
|
||||
})
|
||||
|
||||
// Write scalar fields
|
||||
for _, col := range columns {
|
||||
// Skip if this column is part of a relation that will be output as array field
|
||||
if w.isRelationColumn(col, table) {
|
||||
// We'll output this with the relation field
|
||||
continue
|
||||
}
|
||||
|
||||
sb.WriteString(w.columnToField(col, table, schema))
|
||||
}
|
||||
|
||||
// Write relation fields
|
||||
sb.WriteString(w.generateRelationFields(table, schema, joinTables))
|
||||
|
||||
// Write block attributes (@@id, @@unique, @@index)
|
||||
sb.WriteString(w.generateBlockAttributes(table))
|
||||
|
||||
sb.WriteString("}\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// columnToField converts a Column to a Prisma field definition
|
||||
func (w *Writer) columnToField(col *models.Column, table *models.Table, schema *models.Schema) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Field name
|
||||
sb.WriteString(fmt.Sprintf(" %s", col.Name))
|
||||
|
||||
// Field type
|
||||
prismaType := w.sqlTypeToPrisma(col.Type, schema)
|
||||
sb.WriteString(fmt.Sprintf(" %s", prismaType))
|
||||
|
||||
// Optional modifier
|
||||
if !col.NotNull && !col.IsPrimaryKey {
|
||||
sb.WriteString("?")
|
||||
}
|
||||
|
||||
// Field attributes
|
||||
attributes := w.generateFieldAttributes(col, table)
|
||||
if attributes != "" {
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(attributes)
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// sqlTypeToPrisma converts SQL types to Prisma types
|
||||
func (w *Writer) sqlTypeToPrisma(sqlType string, schema *models.Schema) string {
|
||||
// Check if it's an enum
|
||||
for _, enum := range schema.Enums {
|
||||
if strings.EqualFold(sqlType, enum.Name) {
|
||||
return enum.Name
|
||||
}
|
||||
}
|
||||
|
||||
// Standard type mapping
|
||||
typeMap := map[string]string{
|
||||
"text": "String",
|
||||
"varchar": "String",
|
||||
"character varying": "String",
|
||||
"char": "String",
|
||||
"boolean": "Boolean",
|
||||
"bool": "Boolean",
|
||||
"integer": "Int",
|
||||
"int": "Int",
|
||||
"int4": "Int",
|
||||
"bigint": "BigInt",
|
||||
"int8": "BigInt",
|
||||
"double precision": "Float",
|
||||
"float": "Float",
|
||||
"float8": "Float",
|
||||
"decimal": "Decimal",
|
||||
"numeric": "Decimal",
|
||||
"timestamp": "DateTime",
|
||||
"timestamptz": "DateTime",
|
||||
"date": "DateTime",
|
||||
"jsonb": "Json",
|
||||
"json": "Json",
|
||||
"bytea": "Bytes",
|
||||
}
|
||||
|
||||
for sqlPattern, prismaType := range typeMap {
|
||||
if strings.Contains(strings.ToLower(sqlType), sqlPattern) {
|
||||
return prismaType
|
||||
}
|
||||
}
|
||||
|
||||
// Default to String for unknown types
|
||||
return "String"
|
||||
}
|
||||
|
||||
// generateFieldAttributes generates field attributes like @id, @unique, @default
|
||||
func (w *Writer) generateFieldAttributes(col *models.Column, table *models.Table) string {
|
||||
attrs := make([]string, 0)
|
||||
|
||||
// @id
|
||||
if col.IsPrimaryKey {
|
||||
// Check if this is part of a composite key
|
||||
pkCount := 0
|
||||
for _, c := range table.Columns {
|
||||
if c.IsPrimaryKey {
|
||||
pkCount++
|
||||
}
|
||||
}
|
||||
if pkCount == 1 {
|
||||
attrs = append(attrs, "@id")
|
||||
}
|
||||
}
|
||||
|
||||
// @unique
|
||||
if w.hasUniqueConstraint(col.Name, table) {
|
||||
attrs = append(attrs, "@unique")
|
||||
}
|
||||
|
||||
// @default
|
||||
if col.AutoIncrement {
|
||||
attrs = append(attrs, "@default(autoincrement())")
|
||||
} else if col.Default != nil {
|
||||
defaultAttr := w.formatDefaultValue(col.Default)
|
||||
if defaultAttr != "" {
|
||||
attrs = append(attrs, fmt.Sprintf("@default(%s)", defaultAttr))
|
||||
}
|
||||
}
|
||||
|
||||
// @updatedAt (check comment)
|
||||
if strings.Contains(col.Comment, "@updatedAt") {
|
||||
attrs = append(attrs, "@updatedAt")
|
||||
}
|
||||
|
||||
return strings.Join(attrs, " ")
|
||||
}
|
||||
|
||||
// formatDefaultValue formats a default value for Prisma
|
||||
func (w *Writer) formatDefaultValue(defaultValue any) string {
|
||||
switch v := defaultValue.(type) {
|
||||
case string:
|
||||
if v == "now()" {
|
||||
return "now()"
|
||||
} else if v == "gen_random_uuid()" {
|
||||
return "uuid()"
|
||||
} else if strings.Contains(strings.ToLower(v), "uuid") {
|
||||
return "uuid()"
|
||||
} else {
|
||||
// String literal
|
||||
return fmt.Sprintf(`"%s"`, v)
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%v", v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// hasUniqueConstraint checks if a column has a unique constraint
|
||||
func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint &&
|
||||
len(constraint.Columns) == 1 &&
|
||||
constraint.Columns[0] == colName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isRelationColumn checks if a column is a FK column
|
||||
func (w *Writer) isRelationColumn(col *models.Column, table *models.Table) bool {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
for _, fkCol := range constraint.Columns {
|
||||
if fkCol == col.Name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// generateRelationFields generates relation fields and their FK columns
|
||||
func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Get all FK constraints
|
||||
fks := table.GetForeignKeys()
|
||||
|
||||
for _, fk := range fks {
|
||||
// Generate the FK scalar field
|
||||
for _, fkCol := range fk.Columns {
|
||||
if col, exists := table.Columns[fkCol]; exists {
|
||||
sb.WriteString(w.columnToField(col, table, schema))
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the relation field
|
||||
relationType := fk.ReferencedTable
|
||||
isOptional := false
|
||||
|
||||
// Check if FK column is nullable
|
||||
for _, fkCol := range fk.Columns {
|
||||
if col, exists := table.Columns[fkCol]; exists {
|
||||
if !col.NotNull {
|
||||
isOptional = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
relationName := relationType
|
||||
if strings.HasSuffix(strings.ToLower(relationName), "s") {
|
||||
relationName = relationName[:len(relationName)-1]
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s %s", strings.ToLower(relationName), relationType))
|
||||
|
||||
if isOptional {
|
||||
sb.WriteString("?")
|
||||
}
|
||||
|
||||
// @relation attribute
|
||||
relationAttr := w.generateRelationAttribute(fk)
|
||||
if relationAttr != "" {
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(relationAttr)
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Generate inverse relations (arrays) for tables that reference this one
|
||||
sb.WriteString(w.generateInverseRelations(table, schema, joinTables))
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// generateRelationAttribute generates the @relation(...) attribute
|
||||
func (w *Writer) generateRelationAttribute(fk *models.Constraint) string {
|
||||
parts := make([]string, 0)
|
||||
|
||||
// fields
|
||||
fieldsStr := strings.Join(fk.Columns, ", ")
|
||||
parts = append(parts, fmt.Sprintf("fields: [%s]", fieldsStr))
|
||||
|
||||
// references
|
||||
referencesStr := strings.Join(fk.ReferencedColumns, ", ")
|
||||
parts = append(parts, fmt.Sprintf("references: [%s]", referencesStr))
|
||||
|
||||
// onDelete
|
||||
if fk.OnDelete != "" {
|
||||
parts = append(parts, fmt.Sprintf("onDelete: %s", fk.OnDelete))
|
||||
}
|
||||
|
||||
// onUpdate
|
||||
if fk.OnUpdate != "" {
|
||||
parts = append(parts, fmt.Sprintf("onUpdate: %s", fk.OnUpdate))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("@relation(%s)", strings.Join(parts, ", "))
|
||||
}
|
||||
|
||||
// generateInverseRelations generates array fields for reverse relationships
|
||||
func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Find all tables that have FKs pointing to this table
|
||||
for _, otherTable := range schema.Tables {
|
||||
if otherTable.Name == table.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a join table
|
||||
if joinTables[otherTable.Name] {
|
||||
// Handle implicit M2M
|
||||
if w.isJoinTableFor(otherTable, table.Name) {
|
||||
// Find the other side of the M2M
|
||||
for _, fk := range otherTable.GetForeignKeys() {
|
||||
if fk.ReferencedTable != table.Name {
|
||||
// This is the other side
|
||||
otherSide := fk.ReferencedTable
|
||||
sb.WriteString(fmt.Sprintf(" %ss %s[]\n",
|
||||
strings.ToLower(otherSide), otherSide))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Regular one-to-many inverse relation
|
||||
for _, fk := range otherTable.GetForeignKeys() {
|
||||
if fk.ReferencedTable == table.Name {
|
||||
// This table is referenced by otherTable
|
||||
pluralName := otherTable.Name
|
||||
if !strings.HasSuffix(pluralName, "s") {
|
||||
pluralName += "s"
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s %s[]\n",
|
||||
strings.ToLower(pluralName), otherTable.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// isJoinTableFor checks if a table is a join table involving the specified model
|
||||
func (w *Writer) isJoinTableFor(joinTable *models.Table, modelName string) bool {
|
||||
for _, fk := range joinTable.GetForeignKeys() {
|
||||
if fk.ReferencedTable == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// generateBlockAttributes generates block-level attributes like @@id, @@unique, @@index
|
||||
func (w *Writer) generateBlockAttributes(table *models.Table) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// @@id for composite primary key
|
||||
pkCols := make([]string, 0)
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
pkCols = append(pkCols, col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pkCols) > 1 {
|
||||
sort.Strings(pkCols)
|
||||
sb.WriteString(fmt.Sprintf(" @@id([%s])\n", strings.Join(pkCols, ", ")))
|
||||
}
|
||||
|
||||
// @@unique for multi-column unique constraints
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
|
||||
sb.WriteString(fmt.Sprintf(" @@unique([%s])\n", strings.Join(constraint.Columns, ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
// @@index for indexes
|
||||
for _, index := range table.Indexes {
|
||||
if !index.Unique { // Unique indexes are handled by @@unique
|
||||
sb.WriteString(fmt.Sprintf(" @@index([%s])\n", strings.Join(index.Columns, ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
631
pkg/writers/typeorm/writer.go
Normal file
631
pkg/writers/typeorm/writer.go
Normal file
@@ -0,0 +1,631 @@
|
||||
package typeorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// Writer implements the writers.Writer interface for TypeORM entity format
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
}
|
||||
|
||||
// NewWriter creates a new TypeORM writer with the given options
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
return &Writer{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDatabase writes a Database model to TypeORM entity format
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
content := w.databaseToTypeORM(db)
|
||||
|
||||
if w.options.OutputPath != "" {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
}
|
||||
|
||||
fmt.Print(content)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteSchema writes a Schema model to TypeORM entity format
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
db := models.InitDatabase("database")
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// WriteTable writes a Table model to TypeORM entity format
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
schema := models.InitSchema(table.Schema)
|
||||
schema.Tables = []*models.Table{table}
|
||||
|
||||
return w.WriteSchema(schema)
|
||||
}
|
||||
|
||||
// databaseToTypeORM converts a Database to TypeORM entity format string
|
||||
func (w *Writer) databaseToTypeORM(db *models.Database) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Generate imports
|
||||
sb.WriteString(w.generateImports(db))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Process all schemas
|
||||
for _, schema := range db.Schemas {
|
||||
// Identify join tables
|
||||
joinTables := w.identifyJoinTables(schema)
|
||||
|
||||
// Write entities (excluding join tables)
|
||||
for _, table := range schema.Tables {
|
||||
if joinTables[table.Name] {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(w.tableToEntity(table, schema, joinTables))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write view entities
|
||||
for _, view := range schema.Views {
|
||||
sb.WriteString(w.viewToEntity(view))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// generateImports generates the TypeORM import statement
|
||||
func (w *Writer) generateImports(db *models.Database) string {
|
||||
imports := make([]string, 0)
|
||||
|
||||
// Always include basic decorators
|
||||
imports = append(imports, "Entity", "PrimaryGeneratedColumn", "Column")
|
||||
|
||||
// Check if we need relation decorators
|
||||
needsManyToOne := false
|
||||
needsOneToMany := false
|
||||
needsManyToMany := false
|
||||
needsJoinTable := false
|
||||
needsCreateDate := false
|
||||
needsUpdateDate := false
|
||||
needsViewEntity := false
|
||||
|
||||
for _, schema := range db.Schemas {
|
||||
// Check for views
|
||||
if len(schema.Views) > 0 {
|
||||
needsViewEntity = true
|
||||
}
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Check for timestamp columns
|
||||
for _, col := range table.Columns {
|
||||
if col.Default == "now()" {
|
||||
needsCreateDate = true
|
||||
}
|
||||
if strings.Contains(col.Comment, "auto-update") {
|
||||
needsUpdateDate = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for relations
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
needsManyToOne = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OneToMany is the inverse of ManyToOne
|
||||
if needsManyToOne {
|
||||
needsOneToMany = true
|
||||
}
|
||||
|
||||
// Check for M2M (join tables indicate M2M relations)
|
||||
joinTables := make(map[string]bool)
|
||||
for _, schema := range db.Schemas {
|
||||
jt := w.identifyJoinTables(schema)
|
||||
for name := range jt {
|
||||
joinTables[name] = true
|
||||
needsManyToMany = true
|
||||
needsJoinTable = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsManyToOne {
|
||||
imports = append(imports, "ManyToOne")
|
||||
}
|
||||
if needsOneToMany {
|
||||
imports = append(imports, "OneToMany")
|
||||
}
|
||||
if needsManyToMany {
|
||||
imports = append(imports, "ManyToMany")
|
||||
}
|
||||
if needsJoinTable {
|
||||
imports = append(imports, "JoinTable")
|
||||
}
|
||||
if needsCreateDate {
|
||||
imports = append(imports, "CreateDateColumn")
|
||||
}
|
||||
if needsUpdateDate {
|
||||
imports = append(imports, "UpdateDateColumn")
|
||||
}
|
||||
if needsViewEntity {
|
||||
imports = append(imports, "ViewEntity")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("import { %s } from 'typeorm';\n", strings.Join(imports, ", "))
|
||||
}
|
||||
|
||||
// identifyJoinTables identifies tables that are join tables for M2M relations
|
||||
func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
|
||||
joinTables := make(map[string]bool)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Check if this is a join table:
|
||||
// 1. Has exactly 2 FK constraints
|
||||
// 2. Has composite PK with those 2 columns
|
||||
// 3. Has no other columns except the FK columns
|
||||
|
||||
fks := table.GetForeignKeys()
|
||||
if len(fks) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if columns are only the FK columns
|
||||
if len(table.Columns) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if both FK columns are part of PK
|
||||
pkCols := 0
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
pkCols++
|
||||
}
|
||||
}
|
||||
|
||||
if pkCols == 2 {
|
||||
joinTables[table.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
return joinTables
|
||||
}
|
||||
|
||||
// tableToEntity converts a Table to a TypeORM entity class
|
||||
func (w *Writer) tableToEntity(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Generate @Entity decorator with options
|
||||
entityOptions := w.buildEntityOptions(table)
|
||||
sb.WriteString(fmt.Sprintf("@Entity({\n%s\n})\n", entityOptions))
|
||||
|
||||
// Get class name (from metadata if different from table name)
|
||||
className := table.Name
|
||||
if table.Metadata != nil {
|
||||
if classNameVal, ok := table.Metadata["class_name"]; ok {
|
||||
if classNameStr, ok := classNameVal.(string); ok {
|
||||
className = classNameStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("export class %s {\n", className))
|
||||
|
||||
// Collect and sort columns
|
||||
columns := make([]*models.Column, 0, len(table.Columns))
|
||||
for _, col := range table.Columns {
|
||||
// Skip FK columns (they'll be represented as relations)
|
||||
if w.isForeignKeyColumn(col, table) {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
sort.Slice(columns, func(i, j int) bool {
|
||||
// Put PK first, then alphabetical
|
||||
if columns[i].IsPrimaryKey && !columns[j].IsPrimaryKey {
|
||||
return true
|
||||
}
|
||||
if !columns[i].IsPrimaryKey && columns[j].IsPrimaryKey {
|
||||
return false
|
||||
}
|
||||
return columns[i].Name < columns[j].Name
|
||||
})
|
||||
|
||||
// Write scalar fields
|
||||
for _, col := range columns {
|
||||
sb.WriteString(w.columnToField(col, table))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write relation fields
|
||||
sb.WriteString(w.generateRelationFields(table, schema, joinTables))
|
||||
|
||||
sb.WriteString("}\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// viewToEntity converts a View to a TypeORM @ViewEntity class
|
||||
func (w *Writer) viewToEntity(view *models.View) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Generate @ViewEntity decorator with expression
|
||||
sb.WriteString("@ViewEntity({\n")
|
||||
if view.Definition != "" {
|
||||
// Format the SQL expression with proper indentation
|
||||
sb.WriteString(" expression: `\n")
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(view.Definition)
|
||||
sb.WriteString("\n `,\n")
|
||||
}
|
||||
sb.WriteString("})\n")
|
||||
|
||||
// Generate class
|
||||
sb.WriteString(fmt.Sprintf("export class %s {\n", view.Name))
|
||||
|
||||
// Generate field definitions (without decorators for view fields)
|
||||
columns := make([]*models.Column, 0, len(view.Columns))
|
||||
for _, col := range view.Columns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
sort.Slice(columns, func(i, j int) bool {
|
||||
return columns[i].Name < columns[j].Name
|
||||
})
|
||||
|
||||
for _, col := range columns {
|
||||
tsType := w.sqlTypeToTypeScript(col.Type)
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s;\n", col.Name, tsType))
|
||||
}
|
||||
|
||||
sb.WriteString("}\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// columnToField converts a Column to a TypeORM field
|
||||
func (w *Writer) columnToField(col *models.Column, table *models.Table) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Generate decorator
|
||||
if col.IsPrimaryKey {
|
||||
if col.AutoIncrement {
|
||||
sb.WriteString(" @PrimaryGeneratedColumn('increment')\n")
|
||||
} else if col.Type == "uuid" || strings.Contains(fmt.Sprint(col.Default), "uuid") {
|
||||
sb.WriteString(" @PrimaryGeneratedColumn('uuid')\n")
|
||||
} else {
|
||||
sb.WriteString(" @PrimaryGeneratedColumn()\n")
|
||||
}
|
||||
} else if col.Default == "now()" {
|
||||
sb.WriteString(" @CreateDateColumn()\n")
|
||||
} else if strings.Contains(col.Comment, "auto-update") {
|
||||
sb.WriteString(" @UpdateDateColumn()\n")
|
||||
} else {
|
||||
// Regular @Column decorator
|
||||
options := w.buildColumnOptions(col, table)
|
||||
if options != "" {
|
||||
sb.WriteString(fmt.Sprintf(" @Column({ %s })\n", options))
|
||||
} else {
|
||||
sb.WriteString(" @Column()\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Generate field declaration
|
||||
tsType := w.sqlTypeToTypeScript(col.Type)
|
||||
nullable := ""
|
||||
if !col.NotNull {
|
||||
nullable = " | null"
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s%s;", col.Name, tsType, nullable))
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildColumnOptions builds the options object for @Column decorator
|
||||
func (w *Writer) buildColumnOptions(col *models.Column, table *models.Table) string {
|
||||
options := make([]string, 0)
|
||||
|
||||
// Type (if not default)
|
||||
if w.needsExplicitType(col.Type) {
|
||||
options = append(options, fmt.Sprintf("type: '%s'", col.Type))
|
||||
}
|
||||
|
||||
// Nullable
|
||||
if !col.NotNull {
|
||||
options = append(options, "nullable: true")
|
||||
}
|
||||
|
||||
// Unique
|
||||
if w.hasUniqueConstraint(col.Name, table) {
|
||||
options = append(options, "unique: true")
|
||||
}
|
||||
|
||||
// Default
|
||||
if col.Default != nil && col.Default != "now()" {
|
||||
defaultStr := fmt.Sprint(col.Default)
|
||||
if defaultStr != "" {
|
||||
options = append(options, fmt.Sprintf("default: '%s'", defaultStr))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(options, ", ")
|
||||
}
|
||||
|
||||
// needsExplicitType checks if a SQL type needs explicit type declaration
|
||||
func (w *Writer) needsExplicitType(sqlType string) bool {
|
||||
// Types that don't map cleanly to TypeScript types need explicit declaration
|
||||
explicitTypes := []string{"text", "uuid", "jsonb", "bigint"}
|
||||
for _, t := range explicitTypes {
|
||||
if strings.Contains(sqlType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasUniqueConstraint checks if a column has a unique constraint
|
||||
func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint &&
|
||||
len(constraint.Columns) == 1 &&
|
||||
constraint.Columns[0] == colName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sqlTypeToTypeScript converts SQL types to TypeScript types
|
||||
func (w *Writer) sqlTypeToTypeScript(sqlType string) string {
|
||||
typeMap := map[string]string{
|
||||
"text": "string",
|
||||
"varchar": "string",
|
||||
"character varying": "string",
|
||||
"char": "string",
|
||||
"uuid": "string",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"integer": "number",
|
||||
"int": "number",
|
||||
"bigint": "number",
|
||||
"double precision": "number",
|
||||
"float": "number",
|
||||
"decimal": "number",
|
||||
"numeric": "number",
|
||||
"timestamp": "Date",
|
||||
"timestamptz": "Date",
|
||||
"date": "Date",
|
||||
"jsonb": "any",
|
||||
"json": "any",
|
||||
}
|
||||
|
||||
for sqlPattern, tsType := range typeMap {
|
||||
if strings.Contains(strings.ToLower(sqlType), sqlPattern) {
|
||||
return tsType
|
||||
}
|
||||
}
|
||||
|
||||
return "any"
|
||||
}
|
||||
|
||||
// isForeignKeyColumn checks if a column is a FK column
|
||||
func (w *Writer) isForeignKeyColumn(col *models.Column, table *models.Table) bool {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
for _, fkCol := range constraint.Columns {
|
||||
if fkCol == col.Name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// generateRelationFields generates relation fields for a table
|
||||
func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Get all FK constraints
|
||||
fks := table.GetForeignKeys()
|
||||
|
||||
// Generate @ManyToOne fields
|
||||
for _, fk := range fks {
|
||||
relatedTable := fk.ReferencedTable
|
||||
fieldName := strings.ToLower(relatedTable)
|
||||
|
||||
// Determine if nullable
|
||||
isNullable := false
|
||||
for _, fkCol := range fk.Columns {
|
||||
if col, exists := table.Columns[fkCol]; exists {
|
||||
if !col.NotNull {
|
||||
isNullable = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nullable := ""
|
||||
if isNullable {
|
||||
nullable = " | null"
|
||||
}
|
||||
|
||||
// Find inverse field name if possible
|
||||
inverseField := w.findInverseFieldName(table.Name, relatedTable, schema)
|
||||
|
||||
if inverseField != "" {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, %s => %s.%s)\n",
|
||||
relatedTable, strings.ToLower(relatedTable), strings.ToLower(relatedTable), inverseField))
|
||||
} else {
|
||||
if isNullable {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, { nullable: true })\n", relatedTable))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s)\n", relatedTable))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s%s;\n", fieldName, relatedTable, nullable))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Generate @OneToMany fields (inverse of FKs pointing to this table)
|
||||
w.generateInverseRelations(table, schema, joinTables, &sb)
|
||||
|
||||
// Generate @ManyToMany fields
|
||||
w.generateManyToManyRelations(table, schema, joinTables, &sb)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// findInverseFieldName finds the inverse field name for a relation
|
||||
func (w *Writer) findInverseFieldName(fromTable, toTable string, schema *models.Schema) string {
|
||||
// Look for tables that have FKs pointing back to fromTable
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name != toTable {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint && constraint.ReferencedTable == fromTable {
|
||||
// Found an inverse relation
|
||||
// Use pluralized form of fromTable
|
||||
return w.pluralize(strings.ToLower(fromTable))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// generateInverseRelations generates @OneToMany fields
|
||||
func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool, sb *strings.Builder) {
|
||||
for _, otherTable := range schema.Tables {
|
||||
if otherTable.Name == table.Name || joinTables[otherTable.Name] {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, fk := range otherTable.GetForeignKeys() {
|
||||
if fk.ReferencedTable == table.Name {
|
||||
// This table is referenced by otherTable
|
||||
fieldName := w.pluralize(strings.ToLower(otherTable.Name))
|
||||
inverseName := strings.ToLower(table.Name)
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" @OneToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable.Name, strings.ToLower(otherTable.Name), strings.ToLower(otherTable.Name), inverseName))
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable.Name))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateManyToManyRelations generates @ManyToMany fields
|
||||
func (w *Writer) generateManyToManyRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool, sb *strings.Builder) {
|
||||
for joinTableName := range joinTables {
|
||||
joinTable := w.findTable(joinTableName, schema)
|
||||
if joinTable == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fks := joinTable.GetForeignKeys()
|
||||
if len(fks) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this table is part of the M2M relation
|
||||
var thisTableFK *models.Constraint
|
||||
var otherTableFK *models.Constraint
|
||||
|
||||
for i, fk := range fks {
|
||||
if fk.ReferencedTable == table.Name {
|
||||
thisTableFK = fk
|
||||
if i == 0 {
|
||||
otherTableFK = fks[1]
|
||||
} else {
|
||||
otherTableFK = fks[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if thisTableFK == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine which side owns the relation (has @JoinTable)
|
||||
// We'll make the first entity alphabetically the owner
|
||||
isOwner := table.Name < otherTableFK.ReferencedTable
|
||||
|
||||
otherTable := otherTableFK.ReferencedTable
|
||||
fieldName := w.pluralize(strings.ToLower(otherTable))
|
||||
inverseName := w.pluralize(strings.ToLower(table.Name))
|
||||
|
||||
if isOwner {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName))
|
||||
sb.WriteString(" @JoinTable()\n")
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName))
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// findTable finds a table by name in a schema
|
||||
func (w *Writer) findTable(name string, schema *models.Schema) *models.Table {
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == name {
|
||||
return table
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildEntityOptions builds the options object for @Entity decorator
|
||||
func (w *Writer) buildEntityOptions(table *models.Table) string {
|
||||
options := make([]string, 0)
|
||||
|
||||
// Always include table name
|
||||
options = append(options, fmt.Sprintf(" name: \"%s\"", table.Name))
|
||||
|
||||
// Always include schema
|
||||
options = append(options, fmt.Sprintf(" schema: \"%s\"", table.Schema))
|
||||
|
||||
// Database name from metadata
|
||||
if table.Metadata != nil {
|
||||
if database, ok := table.Metadata["database"]; ok {
|
||||
if databaseStr, ok := database.(string); ok {
|
||||
options = append(options, fmt.Sprintf(" database: \"%s\"", databaseStr))
|
||||
}
|
||||
}
|
||||
|
||||
// Engine from metadata
|
||||
if engine, ok := table.Metadata["engine"]; ok {
|
||||
if engineStr, ok := engine.(string); ok {
|
||||
options = append(options, fmt.Sprintf(" engine: \"%s\"", engineStr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(options, ",\n")
|
||||
}
|
||||
|
||||
// pluralize adds 's' to make a word plural (simple version)
|
||||
func (w *Writer) pluralize(word string) string {
|
||||
if strings.HasSuffix(word, "s") {
|
||||
return word
|
||||
}
|
||||
return word + "s"
|
||||
}
|
||||
Reference in New Issue
Block a user