Implemented TypeORM, Prisma and Enums on a schema

This commit is contained in:
2025-12-19 21:40:46 +02:00
parent 8ca2b50f9c
commit 289715ba44
9 changed files with 3001 additions and 13 deletions

22
TODO.md
View File

@@ -2,21 +2,21 @@
## Input Readers / Writers ## Input Readers / Writers
- [x] **Database Inspector** - [✔️] **Database Inspector**
- [x] PostgreSQL driver - [✔️] PostgreSQL driver
- [ ] MySQL driver - [ ] MySQL driver
- [ ] SQLite driver - [ ] SQLite driver
- [ ] MSSQL driver - [ ] MSSQL driver
- [x] Foreign key detection - [✔️] Foreign key detection
- [x] Index extraction - [✔️] Index extraction
- [ ] .sql file generation with sequence and priority - [ ] .sql file generation with sequence and priority
- [*] .dbml: Database Markup Language (DBML) for textual schema representation. - [✔️] .dbml: Database Markup Language (DBML) for textual schema representation.
- [ ] Prisma schema support (PSL format) .prisma - [✔️] Prisma schema support (PSL format) .prisma
- [ ] Entity Framework (.NET) model .edmx - [☠️] Entity Framework (.NET) model .edmx (Won't do: EDMX files were bloated, verbose XML that was hard to merge and error-prone in team workflows. Microsoft dropped them in EF Core in favor of code-first.)
- [ ] TypeORM support - [✔️] TypeORM support
- [ ] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) - [ ] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) (💲 Someone can do this, not me)
- [ ] Django models.py (Python classes), Sequelize migrations (JS) - [ ] Django models.py (Python classes), Sequelize migrations (JS) (💲 Someone can do this, not me)
- [ ] .avsc: Avro schema (JSON format for data serialization) - [ ] .avsc: Avro schema (JSON format for data serialization) (💲 Someone can do this, not me)

View File

@@ -15,6 +15,8 @@ import (
"git.warky.dev/wdevs/relspecgo/pkg/readers/gorm" "git.warky.dev/wdevs/relspecgo/pkg/readers/gorm"
"git.warky.dev/wdevs/relspecgo/pkg/readers/json" "git.warky.dev/wdevs/relspecgo/pkg/readers/json"
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml" "git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun" wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
@@ -24,6 +26,8 @@ import (
wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm" wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json" wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql" wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml" wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -55,6 +59,8 @@ Input formats:
- yaml: YAML database schema - yaml: YAML database schema
- gorm: GORM model files (Go, file or directory) - gorm: GORM model files (Go, file or directory)
- bun: Bun model files (Go, file or directory) - bun: Bun model files (Go, file or directory)
- prisma: Prisma schema files (.prisma)
- typeorm: TypeORM entity files (TypeScript)
- pgsql: PostgreSQL database (live connection) - pgsql: PostgreSQL database (live connection)
Output formats: Output formats:
@@ -65,6 +71,8 @@ Output formats:
- yaml: YAML database schema - yaml: YAML database schema
- gorm: GORM model files (Go) - gorm: GORM model files (Go)
- bun: Bun model files (Go) - bun: Bun model files (Go)
- prisma: Prisma schema files (.prisma)
- typeorm: TypeORM entity files (TypeScript)
- pgsql: PostgreSQL SQL schema - pgsql: PostgreSQL SQL schema
PostgreSQL Connection String Examples: PostgreSQL Connection String Examples:
@@ -123,11 +131,11 @@ Examples:
} }
func init() { func init() {
convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, pgsql)") convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, prisma, typeorm, pgsql)")
convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)") convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)")
convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for database formats)") convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for database formats)")
convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, json, yaml, gorm, bun, pgsql)") convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, json, yaml, gorm, bun, prisma, typeorm, pgsql)")
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)") convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)") convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)") convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
@@ -239,6 +247,18 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "prisma":
if filePath == "" {
return nil, fmt.Errorf("file path is required for Prisma format")
}
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "typeorm":
if filePath == "" {
return nil, fmt.Errorf("file path is required for TypeORM format")
}
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
default: default:
return nil, fmt.Errorf("unsupported source format: %s", dbType) return nil, fmt.Errorf("unsupported source format: %s", dbType)
} }
@@ -290,6 +310,12 @@ func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaF
case "pgsql", "postgres", "postgresql", "sql": case "pgsql", "postgres", "postgresql", "sql":
writer = wpgsql.NewWriter(writerOpts) writer = wpgsql.NewWriter(writerOpts)
case "prisma":
writer = wprisma.NewWriter(writerOpts)
case "typeorm":
writer = wtypeorm.NewWriter(writerOpts)
default: default:
return fmt.Errorf("unsupported target format: %s", dbType) return fmt.Errorf("unsupported target format: %s", dbType)
} }

View File

@@ -40,6 +40,7 @@ type Schema struct {
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefDatabase *Database `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references RefDatabase *Database `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"` Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"`
Enums []*Enum `json:"enums,omitempty" yaml:"enums,omitempty" xml:"enums"`
} }
// SQLName returns the schema name in lowercase // SQLName returns the schema name in lowercase
@@ -225,6 +226,16 @@ func (d *Constraint) SQLName() string {
type ConstraintType string type ConstraintType string
// Enum models a named enumerated type belonging to a schema (for example a
// PostgreSQL CREATE TYPE ... AS ENUM, or a Prisma `enum` block).
type Enum struct {
	// Name is the enum's declared identifier.
	Name string `json:"name" yaml:"name" xml:"name"`
	// Values are the allowed members, in declaration order.
	Values []string `json:"values" yaml:"values" xml:"values"`
	// Schema is the owning schema's name; empty means unqualified.
	Schema string `json:"schema,omitempty" yaml:"schema,omitempty" xml:"schema,omitempty"`
}
// SQLName returns the enum name lowercased, matching the SQLName convention
// used by the other model types (e.g. Schema, Constraint) in this package.
func (d *Enum) SQLName() string {
	return strings.ToLower(d.Name)
}
const ( const (
PrimaryKeyConstraint ConstraintType = "primary_key" PrimaryKeyConstraint ConstraintType = "primary_key"
ForeignKeyConstraint ConstraintType = "foreign_key" ForeignKeyConstraint ConstraintType = "foreign_key"

View File

@@ -0,0 +1,823 @@
package prisma
import (
"bufio"
"fmt"
"os"
"regexp"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// Reader implements the readers.Reader interface for Prisma schema format
type Reader struct {
	// options carries reader configuration. FilePath is required;
	// Metadata["name"], when present, overrides the database name.
	options *readers.ReaderOptions
}
// NewReader creates a new Prisma reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
	r := &Reader{options: options}
	return r
}
// ReadDatabase reads the Prisma schema file named by the options and parses
// it into a Database model. It fails when no file path was configured or the
// file cannot be read.
func (r *Reader) ReadDatabase() (*models.Database, error) {
	path := r.options.FilePath
	if path == "" {
		return nil, fmt.Errorf("file path is required for Prisma reader")
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}
	return r.parsePrisma(string(data))
}
// ReadSchema parses the Prisma schema and returns its first (and, for Prisma
// input, only) Schema model.
func (r *Reader) ReadSchema() (*models.Schema, error) {
	database, err := r.ReadDatabase()
	if err != nil {
		return nil, err
	}
	if len(database.Schemas) > 0 {
		return database.Schemas[0], nil
	}
	return nil, fmt.Errorf("no schemas found in Prisma schema")
}
// ReadTable parses the Prisma schema and returns the first table of the
// first schema.
func (r *Reader) ReadTable() (*models.Table, error) {
	sch, err := r.ReadSchema()
	if err != nil {
		return nil, err
	}
	if len(sch.Tables) > 0 {
		return sch.Tables[0], nil
	}
	return nil, fmt.Errorf("no tables found in Prisma schema")
}
// stripQuotes trims whitespace and then removes one pair of matching
// surrounding quotes (single or double) from an identifier. Unquoted or
// mismatched-quote input is returned trimmed but otherwise unchanged.
// NOTE(review): this helper is not referenced anywhere in this file —
// confirm it is used elsewhere or remove it.
func stripQuotes(s string) string {
	trimmed := strings.TrimSpace(s)
	if len(trimmed) < 2 {
		return trimmed
	}
	first, last := trimmed[0], trimmed[len(trimmed)-1]
	if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
		return trimmed[1 : len(trimmed)-1]
	}
	return trimmed
}
// parsePrisma parses Prisma schema content and returns a Database model.
//
// It runs a line-oriented state machine over the input: `datasource`,
// `generator`, `model` and `enum` block headers switch the state, a lone "}"
// closes the current block, and intermediate lines are either accumulated
// (datasource/model) or consumed directly (enum values). Relationship
// resolution happens in a second pass (resolveRelationships). All content is
// placed into a single "public" schema, since a plain Prisma file has no
// explicit schema concept.
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
	db := models.InitDatabase("database")
	// An explicit name in reader metadata overrides the placeholder name.
	if r.options.Metadata != nil {
		if name, ok := r.options.Metadata["name"].(string); ok {
			db.Name = name
		}
	}
	// Default schema for Prisma (doesn't have explicit schema concept in most cases)
	schema := models.InitSchema("public")
	schema.Enums = make([]*models.Enum, 0)
	scanner := bufio.NewScanner(strings.NewReader(content))
	// State tracking
	var currentBlock string // "datasource", "generator", "model", "enum"
	var currentTable *models.Table
	var currentEnum *models.Enum
	var blockContent []string
	// Regex patterns
	datasourceRegex := regexp.MustCompile(`^datasource\s+\w+\s*{`)
	generatorRegex := regexp.MustCompile(`^generator\s+\w+\s*{`)
	modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`)
	enumRegex := regexp.MustCompile(`^enum\s+(\w+)\s*{`)
	for scanner.Scan() {
		line := scanner.Text()
		trimmed := strings.TrimSpace(line)
		// Skip empty lines and comments
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}
		// Check for block start
		if matches := datasourceRegex.FindStringSubmatch(trimmed); matches != nil {
			currentBlock = "datasource"
			blockContent = []string{}
			continue
		}
		if matches := generatorRegex.FindStringSubmatch(trimmed); matches != nil {
			currentBlock = "generator"
			blockContent = []string{}
			continue
		}
		if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil {
			currentBlock = "model"
			tableName := matches[1]
			currentTable = models.InitTable(tableName, "public")
			blockContent = []string{}
			continue
		}
		if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil {
			currentBlock = "enum"
			enumName := matches[1]
			currentEnum = &models.Enum{
				Name:   enumName,
				Schema: "public",
				Values: make([]string, 0),
			}
			blockContent = []string{}
			continue
		}
		// Check for block end — a lone "}" closes whichever block is open
		// and dispatches its accumulated content.
		if trimmed == "}" {
			switch currentBlock {
			case "datasource":
				r.parseDatasource(blockContent, db)
			case "generator":
				// We don't need to do anything with generator blocks
			case "model":
				if currentTable != nil {
					r.parseModelFields(blockContent, currentTable)
					schema.Tables = append(schema.Tables, currentTable)
					currentTable = nil
				}
			case "enum":
				if currentEnum != nil {
					schema.Enums = append(schema.Enums, currentEnum)
					currentEnum = nil
				}
			}
			currentBlock = ""
			blockContent = []string{}
			continue
		}
		// Accumulate block content
		if currentBlock != "" {
			if currentBlock == "enum" && currentEnum != nil {
				// For enums, just add the trimmed value.
				// NOTE(review): any non-comment line inside an enum block is
				// taken as a value, so attribute lines such as @@map("...")
				// would also be captured — confirm that is acceptable.
				if trimmed != "" {
					currentEnum.Values = append(currentEnum.Values, trimmed)
				}
			} else {
				blockContent = append(blockContent, line)
			}
		}
	}
	// Second pass: resolve relationships
	r.resolveRelationships(schema)
	db.Schemas = append(db.Schemas, schema)
	return db, nil
}
// parseDatasource reads the provider from a datasource block and maps it
// onto the database model's type. Only the first provider line found is
// used; unknown providers fall back to PostgreSQL.
func (r *Reader) parseDatasource(lines []string, db *models.Database) {
	providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)
	for _, line := range lines {
		m := providerRegex.FindStringSubmatch(line)
		if m == nil {
			continue
		}
		switch m[1] {
		case "postgresql", "postgres":
			db.DatabaseType = models.PostgresqlDatabaseType
		case "mysql":
			db.DatabaseType = "mysql"
		case "sqlite":
			db.DatabaseType = models.SqlLiteDatabaseType
		case "sqlserver":
			db.DatabaseType = models.MSSQLDatabaseType
		default:
			db.DatabaseType = models.PostgresqlDatabaseType
		}
		return
	}
}
// parseModelFields walks the raw lines of a model block, turning field
// definitions into table columns and @@-attributes into constraints/indexes.
func (r *Reader) parseModelFields(lines []string, table *models.Table) {
	fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`)
	blockAttrRegex := regexp.MustCompile(`^@@(\w+)\((.*?)\)`)
	for _, raw := range lines {
		def := strings.TrimSpace(raw)
		// Ignore blanks and comment lines.
		if def == "" || strings.HasPrefix(def, "//") {
			continue
		}
		// Block attributes (@@id, @@unique, @@index) apply to the whole table.
		if attr := blockAttrRegex.FindStringSubmatch(def); attr != nil {
			r.parseBlockAttribute(attr[1], attr[2], table)
			continue
		}
		// Ordinary field: name, type, optional ?/[] modifier, optional @attrs.
		if f := fieldRegex.FindStringSubmatch(def); f != nil {
			if col := r.parseField(f[1], f[2], f[3], f[4], table); col != nil {
				table.Columns[col.Name] = col
			}
		}
	}
}
// parseField parses a single field definition and returns a Column, or nil
// when the field is a relation (array field, or a reference to another
// model) and therefore produces no scalar column of its own.
func (r *Reader) parseField(name, fieldType, modifier, attributes string, table *models.Table) *models.Column {
	// Check if this is a relation field (array or references another model)
	if modifier == "[]" {
		// Array field - this is a relation field, not a column
		// We'll handle this in relationship resolution
		return nil
	}
	// Check if this is a non-primitive type (relation field)
	// Note: We need to allow enum types through as they're like primitives
	// NOTE(review): isEnumType is a naming heuristic (uppercase, no "_"), so a
	// singular reference to another model without @relation (e.g.
	// `profile Profile`) also passes it and falls through to become a column
	// whose type is the model name — confirm this is intended.
	if !r.isPrimitiveType(fieldType) && !r.isEnumType(fieldType, table) {
		// This is a relation field (e.g. user User), not a scalar column
		// Only process this if it has @relation attribute (which means it's the owning side with FK)
		// Otherwise skip it as it's just the inverse relation field
		if attributes == "" || !strings.Contains(attributes, "@relation") {
			return nil
		}
		// If it has @relation, we still don't create a column for it
		// The actual FK column will be in the fields: [...] part of @relation
		return nil
	}
	column := models.InitColumn(name, table.Name, table.Schema)
	// Map Prisma type to SQL type
	column.Type = r.prismaTypeToSQL(fieldType)
	// Handle modifiers: "?" marks the field optional (nullable).
	if modifier == "?" {
		column.NotNull = false
	} else {
		// Default: required fields are NOT NULL
		column.NotNull = true
	}
	// Parse field attributes (@id, @unique, @default, @updatedAt, ...)
	if attributes != "" {
		r.parseFieldAttributes(attributes, column, table)
	}
	return column
}
// prismaTypeToSQL maps a Prisma scalar type onto the SQL type used by the
// generic model. Names that are not built-in scalars (enums, model
// references) are passed through unchanged.
func (r *Reader) prismaTypeToSQL(prismaType string) string {
	switch prismaType {
	case "String":
		return "text"
	case "Boolean":
		return "boolean"
	case "Int":
		return "integer"
	case "BigInt":
		return "bigint"
	case "Float":
		return "double precision"
	case "Decimal":
		return "decimal"
	case "DateTime":
		return "timestamp"
	case "Json":
		return "jsonb"
	case "Bytes":
		return "bytea"
	default:
		// Not a built-in scalar: likely an enum (kept by name) or a model
		// reference.
		return prismaType
	}
}
// parseFieldAttributes applies field-level attributes (@id, @unique,
// @default, @updatedAt) to the column and, for @unique, to the owning table.
func (r *Reader) parseFieldAttributes(attributes string, column *models.Column, table *models.Table) {
	// @id marks the column as (part of) the primary key; PK columns are NOT NULL.
	if strings.Contains(attributes, "@id") {
		column.NotNull = true
		column.IsPrimaryKey = true
	}
	// @unique becomes a single-column unique constraint on the table.
	if regexp.MustCompile(`@unique\b`).MatchString(attributes) {
		uq := models.InitConstraint(
			fmt.Sprintf("uq_%s", column.Name),
			models.UniqueConstraint,
		)
		uq.Schema = table.Schema
		uq.Table = table.Name
		uq.Columns = []string{column.Name}
		table.Constraints[uq.Name] = uq
	}
	// @default(...) — the argument may itself contain parentheses, so the
	// value is extracted with a balanced-paren scan rather than a regex.
	if strings.Contains(attributes, "@default(") {
		if value := r.extractDefaultValue(attributes); value != "" {
			r.parseDefaultValue(value, column)
		}
	}
	// @updatedAt has no direct SQL equivalent here; preserve it as a comment.
	if strings.Contains(attributes, "@updatedAt") {
		if column.Comment == "" {
			column.Comment = "@updatedAt"
		} else {
			column.Comment += "; @updatedAt"
		}
	}
	// @relation is handled in the relationship-resolution pass, not here.
}
// extractDefaultValue returns the argument of @default(...), using a
// parenthesis depth counter so nested calls such as
// @default(dbgenerated("...")) are captured whole. Returns "" when there is
// no @default or its parentheses never balance.
func (r *Reader) extractDefaultValue(attributes string) string {
	const marker = "@default("
	idx := strings.Index(attributes, marker)
	if idx < 0 {
		return ""
	}
	start := idx + len(marker)
	depth := 1
	for pos := start; pos < len(attributes); pos++ {
		switch attributes[pos] {
		case '(':
			depth++
		case ')':
			depth--
		}
		if depth == 0 {
			return attributes[start:pos]
		}
	}
	return ""
}
// parseDefaultValue translates a Prisma @default(...) expression into the
// column's Default value (or AutoIncrement flag / comment where there is no
// direct SQL equivalent).
//
// Handled forms: autoincrement(), now(), uuid(), cuid(), dbgenerated("expr"),
// boolean literals, quoted string literals, and bare numeric/enum tokens.
func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) {
	defaultExpr = strings.TrimSpace(defaultExpr)
	switch defaultExpr {
	case "autoincrement()":
		column.AutoIncrement = true
	case "now()":
		column.Default = "now()"
	case "uuid()":
		// NOTE(review): gen_random_uuid() is PostgreSQL-specific — confirm
		// this mapping is acceptable for other providers.
		column.Default = "gen_random_uuid()"
	case "cuid()":
		// CUID is Prisma-specific, store in comment
		if column.Comment != "" {
			column.Comment += "; default(cuid())"
		} else {
			column.Comment = "default(cuid())"
		}
	case "true":
		column.Default = true
	case "false":
		column.Default = false
	default:
		// dbgenerated("<expr>") carries a raw database expression; unwrap it
		// so the default is the expression itself rather than the wrapper.
		if strings.HasPrefix(defaultExpr, "dbgenerated(") && strings.HasSuffix(defaultExpr, ")") {
			inner := strings.TrimSpace(defaultExpr[len("dbgenerated(") : len(defaultExpr)-1])
			if len(inner) >= 2 && ((inner[0] == '"' && inner[len(inner)-1] == '"') ||
				(inner[0] == '\'' && inner[len(inner)-1] == '\'')) {
				inner = inner[1 : len(inner)-1]
			}
			column.Default = inner
			return
		}
		// Quoted string literal: keep the content without the quotes.
		if strings.HasPrefix(defaultExpr, "\"") && strings.HasSuffix(defaultExpr, "\"") {
			column.Default = defaultExpr[1 : len(defaultExpr)-1]
		} else if strings.HasPrefix(defaultExpr, "'") && strings.HasSuffix(defaultExpr, "'") {
			column.Default = defaultExpr[1 : len(defaultExpr)-1]
		} else {
			// Bare token: number, enum value, or an unrecognized function call.
			column.Default = defaultExpr
		}
	}
}
// parseBlockAttribute parses block-level attributes (@@id, @@unique,
// @@index), creating the corresponding composite primary key, unique
// constraint, or index on the table. Only the first bracketed [col1, col2]
// list in the attribute content is used; named arguments such as map: are
// ignored.
func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Table) {
	// Extract column list from brackets [col1, col2]
	colListRegex := regexp.MustCompile(`\[(.*?)\]`)
	matches := colListRegex.FindStringSubmatch(content)
	if matches == nil {
		return
	}
	columnList := strings.Split(matches[1], ",")
	columns := make([]string, 0)
	for _, col := range columnList {
		columns = append(columns, strings.TrimSpace(col))
	}
	switch attrName {
	case "id":
		// Composite primary key: flag each member column...
		for _, colName := range columns {
			if col, exists := table.Columns[colName]; exists {
				col.IsPrimaryKey = true
				col.NotNull = true
			}
		}
		// ...and also create a PK constraint on the table.
		pkConstraint := models.InitConstraint(
			fmt.Sprintf("pk_%s", table.Name),
			models.PrimaryKeyConstraint,
		)
		pkConstraint.Schema = table.Schema
		pkConstraint.Table = table.Name
		pkConstraint.Columns = columns
		table.Constraints[pkConstraint.Name] = pkConstraint
	case "unique":
		// Multi-column unique constraint, named uq_<table>_<cols...>.
		uniqueConstraint := models.InitConstraint(
			fmt.Sprintf("uq_%s_%s", table.Name, strings.Join(columns, "_")),
			models.UniqueConstraint,
		)
		uniqueConstraint.Schema = table.Schema
		uniqueConstraint.Table = table.Name
		uniqueConstraint.Columns = columns
		table.Constraints[uniqueConstraint.Name] = uniqueConstraint
	case "index":
		// Secondary index, named idx_<table>_<cols...>.
		index := models.InitIndex(
			fmt.Sprintf("idx_%s_%s", table.Name, strings.Join(columns, "_")),
			table.Name,
			table.Schema,
		)
		index.Columns = columns
		table.Indexes[index.Name] = index
	}
}
// relationField stores information about a relation field for second-pass processing
type relationField struct {
	tableName    string // model that declares the field
	fieldName    string // field identifier inside the model
	relatedModel string // type name of the referenced model
	isArray      bool   // true for "Model[]" (to-many / inverse side)
	relationAttr string // raw "@..." attribute text; "" when absent
}
// resolveRelationships performs a second pass to resolve @relation
// attributes into foreign-key constraints and to synthesize join tables for
// implicit many-to-many relations.
//
// NOTE(review): this re-reads the schema file from disk instead of reusing
// the content already parsed by parsePrisma, and both the missing-path and
// read-error cases return silently, leaving relationships unresolved —
// consider passing the content through and surfacing the error.
func (r *Reader) resolveRelationships(schema *models.Schema) {
	// Build a map of table names for quick lookup
	tableMap := make(map[string]*models.Table)
	for _, table := range schema.Tables {
		tableMap[table.Name] = table
	}
	// First, we need to re-parse to find relation fields
	// We'll re-read the file to extract relation information
	if r.options.FilePath == "" {
		return
	}
	content, err := os.ReadFile(r.options.FilePath)
	if err != nil {
		return
	}
	relations := r.extractRelationFields(string(content))
	// Process explicit @relation attributes to create FK constraints
	for _, rel := range relations {
		if rel.relationAttr != "" {
			r.createConstraintFromRelation(rel, tableMap, schema)
		}
	}
	// Detect implicit many-to-many relationships
	r.detectImplicitManyToMany(relations, tableMap, schema)
}
// extractRelationFields scans the schema text for model fields whose type is
// another model or whose modifier is an array, collecting them for
// foreign-key and many-to-many resolution.
func (r *Reader) extractRelationFields(content string) []relationField {
	modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`)
	fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`)
	relations := make([]relationField, 0)
	var currentModel string
	inModel := false
	scanner := bufio.NewScanner(strings.NewReader(content))
	for scanner.Scan() {
		trimmed := strings.TrimSpace(scanner.Text())
		// Skip blanks and comments.
		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
			continue
		}
		// A lone "}" ends the model currently being scanned.
		if trimmed == "}" {
			inModel = false
			currentModel = ""
			continue
		}
		if m := modelRegex.FindStringSubmatch(trimmed); m != nil {
			currentModel = m[1]
			inModel = true
			continue
		}
		if !inModel || currentModel == "" {
			continue
		}
		f := fieldRegex.FindStringSubmatch(trimmed)
		if f == nil {
			continue
		}
		name, typ, mod, attrs := f[1], f[2], f[3], f[4]
		// Relation-like when the field is an array or a non-scalar type.
		if mod == "[]" || !r.isPrimitiveType(typ) {
			relations = append(relations, relationField{
				tableName:    currentModel,
				fieldName:    name,
				relatedModel: typ,
				isArray:      mod == "[]",
				relationAttr: attrs,
			})
		}
	}
	return relations
}
// isPrimitiveType reports whether typeName is one of Prisma's built-in
// scalar types; anything else is an enum or a model reference.
//
// Uses a switch instead of allocating and scanning a slice on every call —
// this function runs once per field on both parsing passes.
func (r *Reader) isPrimitiveType(typeName string) bool {
	switch typeName {
	case "String", "Boolean", "Int", "BigInt", "Float", "Decimal", "DateTime", "Json", "Bytes":
		return true
	default:
		return false
	}
}
// isEnumType checks if a type name might be an enum
// Note: We can't definitively check against schema.Enums at parse time
// because enums might be defined after the model, so we just check
// if it starts with uppercase (Prisma convention for enums)
//
// NOTE(review): the `table` parameter is unused. Also, model names follow the
// same uppercase/no-underscore convention, so references to other models
// satisfy this heuristic too — see the hedge in parseField. A two-pass
// enum registry would make this exact.
func (r *Reader) isEnumType(typeName string, table *models.Table) bool {
	// Simple heuristic: enum types start with uppercase letter
	// and are not known model names (though we can't check that yet)
	if len(typeName) > 0 && typeName[0] >= 'A' && typeName[0] <= 'Z' {
		// Additional check: primitive types are already handled above
		// So if it's uppercase and not primitive, it's likely an enum or model
		// We'll assume it's an enum if it's a single word
		return !strings.Contains(typeName, "_")
	}
	return false
}
// createConstraintFromRelation creates a FK constraint on the owning table
// from a field's @relation attribute. Array fields (the inverse side) and
// fields without an attribute are skipped.
//
// NOTE(review): the name regex only matches the named form `name: "..."`;
// Prisma also allows a positional first argument (`@relation("label", ...)`),
// which is not picked up here — confirm whether that form should be
// supported. The non-greedy `@relation\((.*?)\)` match would also truncate
// at the first ')' if the attribute content ever contained nested parens.
func (r *Reader) createConstraintFromRelation(rel relationField, tableMap map[string]*models.Table, schema *models.Schema) {
	// Skip array fields (they are the inverse side of the relation)
	if rel.isArray {
		return
	}
	if rel.relationAttr == "" {
		return
	}
	// Parse @relation attribute
	relationRegex := regexp.MustCompile(`@relation\((.*?)\)`)
	matches := relationRegex.FindStringSubmatch(rel.relationAttr)
	if matches == nil {
		return
	}
	relationContent := matches[1]
	// Extract fields and references
	fieldsRegex := regexp.MustCompile(`fields:\s*\[(.*?)\]`)
	referencesRegex := regexp.MustCompile(`references:\s*\[(.*?)\]`)
	nameRegex := regexp.MustCompile(`name:\s*"([^"]+)"`)
	onDeleteRegex := regexp.MustCompile(`onDelete:\s*(\w+)`)
	onUpdateRegex := regexp.MustCompile(`onUpdate:\s*(\w+)`)
	fieldsMatch := fieldsRegex.FindStringSubmatch(relationContent)
	referencesMatch := referencesRegex.FindStringSubmatch(relationContent)
	// Both fields: [...] and references: [...] are required for a FK.
	if fieldsMatch == nil || referencesMatch == nil {
		return
	}
	// Parse field and reference column lists
	fieldCols := r.parseColumnList(fieldsMatch[1])
	refCols := r.parseColumnList(referencesMatch[1])
	if len(fieldCols) == 0 || len(refCols) == 0 {
		return
	}
	// Create FK constraint, default-named fk_<table>_<firstCol>.
	constraintName := fmt.Sprintf("fk_%s_%s", rel.tableName, fieldCols[0])
	// Check for custom name
	if nameMatch := nameRegex.FindStringSubmatch(relationContent); nameMatch != nil {
		constraintName = nameMatch[1]
	}
	constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
	constraint.Schema = "public"
	constraint.Table = rel.tableName
	constraint.Columns = fieldCols
	constraint.ReferencedSchema = "public"
	constraint.ReferencedTable = rel.relatedModel
	constraint.ReferencedColumns = refCols
	// Parse referential actions (stored verbatim, e.g. "Cascade").
	if onDeleteMatch := onDeleteRegex.FindStringSubmatch(relationContent); onDeleteMatch != nil {
		constraint.OnDelete = onDeleteMatch[1]
	}
	if onUpdateMatch := onUpdateRegex.FindStringSubmatch(relationContent); onUpdateMatch != nil {
		constraint.OnUpdate = onUpdateMatch[1]
	}
	// Add constraint to the owning table, if it was parsed.
	if table, exists := tableMap[rel.tableName]; exists {
		table.Constraints[constraint.Name] = constraint
	}
}
// parseColumnList splits a comma-separated column list, trimming whitespace
// around each entry and dropping empties.
func (r *Reader) parseColumnList(list string) []string {
	cols := make([]string, 0)
	for _, piece := range strings.Split(list, ",") {
		if name := strings.TrimSpace(piece); name != "" {
			cols = append(cols, name)
		}
	}
	return cols
}
// detectImplicitManyToMany detects implicit M2M relationships — array
// relation fields with no @relation attribute on both sides of a model pair —
// and synthesizes a join table for each such pair.
func (r *Reader) detectImplicitManyToMany(relations []relationField, tableMap map[string]*models.Table, schema *models.Schema) {
	// Group relations by model pairs
	type modelPair struct {
		model1 string
		model2 string
	}
	pairMap := make(map[modelPair][]relationField)
	for _, rel := range relations {
		if !rel.isArray || rel.relationAttr != "" {
			// Skip non-array fields and explicit relations
			continue
		}
		// Create a normalized pair (alphabetically sorted to avoid duplicates)
		pair := modelPair{}
		if rel.tableName < rel.relatedModel {
			pair.model1 = rel.tableName
			pair.model2 = rel.relatedModel
		} else {
			pair.model1 = rel.relatedModel
			pair.model2 = rel.tableName
		}
		pairMap[pair] = append(pairMap[pair], rel)
	}
	// Check for pairs with arrays on both sides (implicit M2M).
	// NOTE(review): two array fields on the SAME side (e.g. a self-relation
	// with two list fields) would also reach the >= 2 threshold — confirm
	// that case is acceptable or out of scope.
	for pair, rels := range pairMap {
		if len(rels) >= 2 {
			// This is an implicit many-to-many relationship
			r.createImplicitJoinTable(pair.model1, pair.model2, tableMap, schema)
		}
	}
}
// createImplicitJoinTable creates a virtual join table for an implicit M2M
// relation between model1 and model2 (already alphabetically ordered by the
// caller), following Prisma's _AToB naming convention. The join table gets
// one FK column per side, a composite PK over both, and cascading FKs back
// to each model's primary key.
func (r *Reader) createImplicitJoinTable(model1, model2 string, tableMap map[string]*models.Table, schema *models.Schema) {
	// Prisma naming convention: _Model1ToModel2 (alphabetically sorted)
	joinTableName := fmt.Sprintf("_%sTo%s", model1, model2)
	// Check if join table already exists
	if _, exists := tableMap[joinTableName]; exists {
		return
	}
	// Create join table
	joinTable := models.InitTable(joinTableName, "public")
	// Get primary keys from both tables.
	// NOTE(review): for composite PKs this picks a single column — see
	// getPrimaryKeyColumn; confirm composite-PK models are out of scope here.
	pk1 := r.getPrimaryKeyColumn(tableMap[model1])
	pk2 := r.getPrimaryKeyColumn(tableMap[model2])
	if pk1 == nil || pk2 == nil {
		return // Can't create join table without PKs
	}
	// Create FK columns in join table, typed to match the referenced PKs.
	fkCol1Name := fmt.Sprintf("%sId", model1)
	fkCol1 := models.InitColumn(fkCol1Name, joinTableName, "public")
	fkCol1.Type = pk1.Type
	fkCol1.NotNull = true
	joinTable.Columns[fkCol1Name] = fkCol1
	fkCol2Name := fmt.Sprintf("%sId", model2)
	fkCol2 := models.InitColumn(fkCol2Name, joinTableName, "public")
	fkCol2.Type = pk2.Type
	fkCol2.NotNull = true
	joinTable.Columns[fkCol2Name] = fkCol2
	// Create composite primary key over both FK columns.
	pkConstraint := models.InitConstraint(
		fmt.Sprintf("pk_%s", joinTableName),
		models.PrimaryKeyConstraint,
	)
	pkConstraint.Schema = "public"
	pkConstraint.Table = joinTableName
	pkConstraint.Columns = []string{fkCol1Name, fkCol2Name}
	joinTable.Constraints[pkConstraint.Name] = pkConstraint
	// Mark columns as PK
	fkCol1.IsPrimaryKey = true
	fkCol2.IsPrimaryKey = true
	// Create FK constraints back to each model (Cascade mirrors Prisma's
	// behavior for implicit join tables).
	fk1 := models.InitConstraint(
		fmt.Sprintf("fk_%s_%s", joinTableName, model1),
		models.ForeignKeyConstraint,
	)
	fk1.Schema = "public"
	fk1.Table = joinTableName
	fk1.Columns = []string{fkCol1Name}
	fk1.ReferencedSchema = "public"
	fk1.ReferencedTable = model1
	fk1.ReferencedColumns = []string{pk1.Name}
	fk1.OnDelete = "Cascade"
	joinTable.Constraints[fk1.Name] = fk1
	fk2 := models.InitConstraint(
		fmt.Sprintf("fk_%s_%s", joinTableName, model2),
		models.ForeignKeyConstraint,
	)
	fk2.Schema = "public"
	fk2.Table = joinTableName
	fk2.Columns = []string{fkCol2Name}
	fk2.ReferencedSchema = "public"
	fk2.ReferencedTable = model2
	fk2.ReferencedColumns = []string{pk2.Name}
	fk2.OnDelete = "Cascade"
	joinTable.Constraints[fk2.Name] = fk2
	// Add join table to schema
	schema.Tables = append(schema.Tables, joinTable)
	tableMap[joinTableName] = joinTable
}
// getPrimaryKeyColumn returns the primary-key column of a table, or nil when
// the table is nil or has no PK column.
//
// Table.Columns is a map, so with a composite primary key a plain range
// would return an arbitrary column on each run; candidates are therefore
// ordered by name to keep the result deterministic.
func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column {
	if table == nil {
		return nil
	}
	var best *models.Column
	for _, col := range table.Columns {
		if !col.IsPrimaryKey {
			continue
		}
		if best == nil || col.Name < best.Name {
			best = col
		}
	}
	return best
}

View File

@@ -0,0 +1,785 @@
package typeorm
import (
"bufio"
"fmt"
"os"
"regexp"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// Reader implements the readers.Reader interface for TypeORM entity files
type Reader struct {
	// options carries reader configuration; FilePath must point at a
	// TypeScript entity file.
	options *readers.ReaderOptions
}
// NewReader creates a new TypeORM reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
	reader := &Reader{options: options}
	return reader
}
// ReadDatabase reads the TypeORM entity file named by the options and parses
// it into a Database model. It fails when no file path was configured or the
// file cannot be read.
func (r *Reader) ReadDatabase() (*models.Database, error) {
	path := r.options.FilePath
	if path == "" {
		return nil, fmt.Errorf("file path is required for TypeORM reader")
	}
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}
	return r.parseTypeORM(string(raw))
}
// ReadSchema parses the TypeORM entities and returns the first Schema model.
func (r *Reader) ReadSchema() (*models.Schema, error) {
	database, err := r.ReadDatabase()
	if err != nil {
		return nil, err
	}
	if len(database.Schemas) > 0 {
		return database.Schemas[0], nil
	}
	return nil, fmt.Errorf("no schemas found in TypeORM entities")
}
// ReadTable parses the TypeORM entities and returns the first table of the
// first schema.
func (r *Reader) ReadTable() (*models.Table, error) {
	sch, err := r.ReadSchema()
	if err != nil {
		return nil, err
	}
	if len(sch.Tables) > 0 {
		return sch.Tables[0], nil
	}
	return nil, fmt.Errorf("no tables found in TypeORM entities")
}
// entityInfo stores information about an entity during parsing
type entityInfo struct {
	name       string      // exported class name from `export class <Name>`
	fields     []fieldInfo // field declarations collected inside the class body
	decorators []string    // class-level decorators, e.g. @Entity(...), @ViewEntity(...)
}
// fieldInfo stores information about a field during parsing
type fieldInfo struct {
	name       string   // property name as declared in the class body
	typeName   string   // raw TypeScript type annotation (may include [] or `| null`)
	decorators []string // decorators attached to this field, e.g. @Column(...)
}
// parseTypeORM parses TypeORM entity content into a Database model: a first
// pass materializes tables and views into a single "public" schema, a second
// pass resolves relationships between the materialized tables.
func (r *Reader) parseTypeORM(content string) (*models.Database, error) {
	database := models.InitDatabase("database")
	publicSchema := models.InitSchema("public")

	parsed := r.extractEntities(content)

	// First pass: materialize tables and views.
	byName := make(map[string]*models.Table)
	for _, ent := range parsed {
		viewEntity := false
		for _, dec := range ent.decorators {
			if strings.HasPrefix(dec, "@ViewEntity") {
				viewEntity = true
				break
			}
		}
		if viewEntity {
			publicSchema.Views = append(publicSchema.Views, r.entityToView(ent))
			continue
		}
		tbl := r.entityToTable(ent)
		publicSchema.Tables = append(publicSchema.Tables, tbl)
		byName[tbl.Name] = tbl
	}

	// Second pass: resolve FK and M2M relationships.
	r.resolveRelationships(parsed, byName, publicSchema)

	database.Schemas = append(database.Schemas, publicSchema)
	return database, nil
}
// extractEntities extracts entity and view definitions from TypeORM content.
// Decorators are first normalized onto single lines, then the content is
// scanned line by line, accumulating decorators until the class or field
// they annotate appears.
func (r *Reader) extractEntities(content string) []entityInfo {
	entities := make([]entityInfo, 0)
	content = r.normalizeDecorators(content)
	scanner := bufio.NewScanner(strings.NewReader(content))
	entityRegex := regexp.MustCompile(`^export\s+class\s+(\w+)`)
	decoratorRegex := regexp.MustCompile(`^\s*@(\w+)(\([^)]*\))?`)
	// Field declarations: `name: type;` — also accepts TypeScript's optional
	// (`name?:`) and definite-assignment (`name!:`) markers, which the
	// previous pattern silently skipped.
	fieldRegex := regexp.MustCompile(`^\s*(\w+)[?!]?:\s*([^;]+);`)
	var currentEntity *entityInfo
	var pendingDecorators []string
	inClass := false
	for scanner.Scan() {
		trimmed := strings.TrimSpace(scanner.Text())
		// Skip blanks, line comments, and import statements.
		if trimmed == "" || strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "import ") {
			continue
		}
		// Decorator: queue it for the next class/field declaration.
		if matches := decoratorRegex.FindStringSubmatch(trimmed); matches != nil {
			pendingDecorators = append(pendingDecorators, matches[0])
			continue
		}
		// Class header: flush any previous entity and start a new one.
		if matches := entityRegex.FindStringSubmatch(trimmed); matches != nil {
			if currentEntity != nil {
				entities = append(entities, *currentEntity)
			}
			currentEntity = &entityInfo{
				name:       matches[1],
				fields:     make([]fieldInfo, 0),
				decorators: pendingDecorators,
			}
			pendingDecorators = []string{}
			inClass = true
			continue
		}
		// A bare closing brace ends the current class.
		if inClass && trimmed == "}" {
			if currentEntity != nil {
				entities = append(entities, *currentEntity)
				currentEntity = nil
			}
			inClass = false
			pendingDecorators = []string{}
			continue
		}
		// Field declaration inside the class body.
		if inClass && currentEntity != nil {
			if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil {
				currentEntity.fields = append(currentEntity.fields, fieldInfo{
					name:       matches[1],
					typeName:   strings.TrimSpace(matches[2]),
					decorators: pendingDecorators,
				})
				pendingDecorators = []string{}
			}
		}
	}
	// Flush the final entity when input does not end with "}".
	if currentEntity != nil {
		entities = append(entities, *currentEntity)
	}
	return entities
}
// normalizeDecorators combines multi-line decorators into single lines so the
// line-oriented scanner in extractEntities can match them.
func (r *Reader) normalizeDecorators(content string) string {
	// Match @Decorator({ ... }) even when the options object spans lines.
	decoratorRegex := regexp.MustCompile(`@(\w+)\s*\(\s*\{([^}]*)\}\s*\)`)
	// Hoisted out of the replacement closure: the original recompiled this
	// regex on every decorator match, which is wasted work.
	spaceRegex := regexp.MustCompile(`\s+`)
	return decoratorRegex.ReplaceAllStringFunc(content, func(match string) string {
		// Collapse the decorator onto one line and squeeze runs of whitespace.
		match = strings.ReplaceAll(match, "\n", " ")
		match = strings.ReplaceAll(match, "\r", " ")
		return spaceRegex.ReplaceAllString(match, " ")
	})
}
// entityToView converts a @ViewEntity class into a view model, honoring the
// name, schema and expression options from the decorator.
func (r *Reader) entityToView(entity entityInfo) *models.View {
	viewName, schemaName := entity.name, "public"
	var expression string
	for _, dec := range entity.decorators {
		if !strings.HasPrefix(dec, "@ViewEntity") {
			continue
		}
		opts := r.parseViewEntityOptions(dec)
		if v, ok := opts["name"]; ok {
			viewName = v
		}
		if v, ok := opts["schema"]; ok {
			schemaName = v
		}
		if v, ok := opts["expression"]; ok {
			expression = v
		}
		break
	}
	view := models.InitView(viewName, schemaName)
	view.Definition = expression
	// Any fields declared on the view class become typed view columns.
	for _, f := range entity.fields {
		col := models.InitColumn(f.name, viewName, schemaName)
		col.Type = r.typeScriptTypeToSQL(f.typeName)
		view.Columns[col.Name] = col
	}
	return view
}
// parseViewEntityOptions parses @ViewEntity decorator options (name, schema,
// expression) into a flat string map. The expression value may be wrapped in
// backticks (template literal) or plain quotes.
func (r *Reader) parseViewEntityOptions(decorator string) map[string]string {
	opts := make(map[string]string)
	open := strings.Index(decorator, "(")
	end := strings.LastIndex(decorator, ")")
	if open == -1 || end == -1 || open >= end {
		return opts
	}
	body := decorator[open+1 : end]
	// Bare @ViewEntity() carries no options.
	if strings.TrimSpace(body) == "" {
		return opts
	}
	// Simple quoted string options share one pattern.
	for _, key := range []string{"name", "schema"} {
		re := regexp.MustCompile(key + `:\s*["']([^"']+)["']`)
		if m := re.FindStringSubmatch(body); m != nil {
			opts[key] = m[1]
		}
	}
	// expression: prefer the backtick (template literal) form, fall back to quotes.
	backtickRe := regexp.MustCompile("expression:\\s*`([^`]+)`")
	if m := backtickRe.FindStringSubmatch(body); m != nil {
		opts["expression"] = strings.TrimSpace(m[1])
	} else {
		quotedRe := regexp.MustCompile(`expression:\s*["']([^"']+)["']`)
		if m := quotedRe.FindStringSubmatch(body); m != nil {
			opts["expression"] = strings.TrimSpace(m[1])
		}
	}
	return opts
}
// entityToTable converts an @Entity class into a table model. The decorator's
// name/schema options override the defaults (class name, "public"); database,
// engine and the original class name are preserved in table metadata.
func (r *Reader) entityToTable(entity entityInfo) *models.Table {
	tableName, schemaName := entity.name, "public"
	var opts map[string]string
	for _, dec := range entity.decorators {
		if !strings.HasPrefix(dec, "@Entity") {
			continue
		}
		opts = r.parseEntityOptions(dec)
		if v, ok := opts["name"]; ok {
			tableName = v
		}
		if v, ok := opts["schema"]; ok {
			schemaName = v
		}
		break
	}

	table := models.InitTable(tableName, schemaName)

	// setMeta lazily allocates the metadata map on first use.
	setMeta := func(key string, value any) {
		if table.Metadata == nil {
			table.Metadata = make(map[string]any)
		}
		table.Metadata[key] = value
	}
	if opts != nil {
		if v, ok := opts["database"]; ok {
			setMeta("database", v)
		}
		if v, ok := opts["engine"]; ok {
			setMeta("engine", v)
		}
		// Keep the class name around when @Entity renamed the table.
		if entity.name != tableName {
			setMeta("class_name", entity.name)
		}
	}

	for _, f := range entity.fields {
		// Relation fields become FK columns/constraints during relationship resolution.
		if r.isRelationField(f) {
			continue
		}
		if col := r.fieldToColumn(f, table); col != nil {
			table.Columns[col.Name] = col
		}
	}
	return table
}
// parseEntityOptions parses @Entity decorator options (name, schema,
// database, engine) into a flat string map.
func (r *Reader) parseEntityOptions(decorator string) map[string]string {
	opts := make(map[string]string)
	open := strings.Index(decorator, "(")
	end := strings.LastIndex(decorator, ")")
	if open == -1 || end == -1 || open >= end {
		return opts
	}
	body := decorator[open+1 : end]
	// Bare @Entity() carries no options.
	if strings.TrimSpace(body) == "" {
		return opts
	}
	// All supported options are simple quoted strings sharing one pattern.
	for _, key := range []string{"name", "schema", "database", "engine"} {
		re := regexp.MustCompile(key + `:\s*["']([^"']+)["']`)
		if m := re.FindStringSubmatch(body); m != nil {
			opts[key] = m[1]
		}
	}
	return opts
}
// isRelationField reports whether any decorator marks the field as a
// TypeORM relation (ManyToOne, OneToMany, ManyToMany or OneToOne).
func (r *Reader) isRelationField(field fieldInfo) bool {
	relationMarkers := []string{"@ManyToOne", "@OneToMany", "@ManyToMany", "@OneToOne"}
	for _, dec := range field.decorators {
		for _, marker := range relationMarkers {
			if strings.Contains(dec, marker) {
				return true
			}
		}
	}
	return false
}
// fieldToColumn builds a column model from a parsed class field, applying
// the mapped SQL type and any column decorators.
func (r *Reader) fieldToColumn(field fieldInfo, table *models.Table) *models.Column {
	col := models.InitColumn(field.name, table.Name, table.Schema)
	col.Type = r.typeScriptTypeToSQL(field.typeName)
	// Columns default to NOT NULL unless a decorator relaxes it.
	col.NotNull = true
	for _, dec := range field.decorators {
		r.parseColumnDecorator(dec, col, table)
	}
	return col
}
// typeScriptTypeToSQL converts TypeScript types to SQL types.
// Patterns are checked against an ordered list (not a map) so the result is
// deterministic when a type string could match several patterns (e.g. the
// union "string | number"); the original map iteration order was random.
func (r *Reader) typeScriptTypeToSQL(tsType string) string {
	// Strip array brackets and trailing `| null` union markers.
	tsType = strings.TrimSuffix(tsType, "[]")
	tsType = strings.TrimSuffix(tsType, " | null")
	mappings := []struct {
		tsPattern string
		sqlType   string
	}{
		{"string", "text"},
		{"number", "integer"},
		{"boolean", "boolean"},
		{"Date", "timestamp"},
		{"any", "jsonb"},
	}
	for _, m := range mappings {
		if strings.Contains(tsType, m.tsPattern) {
			return m.sqlType
		}
	}
	// Unknown types fall back to text.
	return "text"
}
// parseColumnDecorator applies a single field decorator to the column:
// primary-key generation, @Column options, and create/update timestamps.
func (r *Reader) parseColumnDecorator(decorator string, column *models.Column, table *models.Table) {
	switch {
	case strings.HasPrefix(decorator, "@PrimaryGeneratedColumn"):
		column.IsPrimaryKey = true
		column.NotNull = true
		switch {
		case strings.Contains(decorator, "'uuid'"):
			column.Type = "uuid"
			column.Default = "gen_random_uuid()"
		case strings.Contains(decorator, "'increment'"), strings.Contains(decorator, "()"):
			// Bare @PrimaryGeneratedColumn() or explicit 'increment' strategy.
			column.AutoIncrement = true
		}
	case strings.HasPrefix(decorator, "@Column"):
		r.parseColumnOptions(decorator, column, table)
	case strings.HasPrefix(decorator, "@CreateDateColumn"):
		column.Type = "timestamp"
		column.Default = "now()"
		column.NotNull = true
	case strings.HasPrefix(decorator, "@UpdateDateColumn"):
		column.Type = "timestamp"
		column.NotNull = true
		// Tag the column so writers can recognize auto-updated timestamps.
		if column.Comment == "" {
			column.Comment = "auto-update"
		} else {
			column.Comment += "; auto-update"
		}
	}
}
// parseColumnOptions parses @Column decorator options into the column and,
// for unique columns, registers a unique constraint on the table.
func (r *Reader) parseColumnOptions(decorator string, column *models.Column, table *models.Table) {
	open := strings.Index(decorator, "(")
	end := strings.LastIndex(decorator, ")")
	if open == -1 || end == -1 || open >= end {
		return
	}
	body := decorator[open+1 : end]

	// Shorthand form: @Column('text') — the whole body is the SQL type.
	if strings.HasPrefix(body, "'") || strings.HasPrefix(body, "\"") {
		column.Type = strings.Trim(body, "'\"`")
		return
	}

	// Options-object form: @Column({ type: ..., nullable: ..., ... }).
	if strings.Contains(body, "type:") {
		if m := regexp.MustCompile(`type:\s*['"]([^'"]+)['"]`).FindStringSubmatch(body); m != nil {
			column.Type = m[1]
		}
	}
	if strings.Contains(body, "nullable: true") || strings.Contains(body, "nullable:true") {
		column.NotNull = false
	}
	if strings.Contains(body, "unique: true") || strings.Contains(body, "unique:true") {
		uq := models.InitConstraint(
			fmt.Sprintf("uq_%s", column.Name),
			models.UniqueConstraint,
		)
		uq.Schema = table.Schema
		uq.Table = table.Name
		uq.Columns = []string{column.Name}
		table.Constraints[uq.Name] = uq
	}
	if strings.Contains(body, "default:") {
		if m := regexp.MustCompile(`default:\s*['"]?([^,}'"]+)['"]?`).FindStringSubmatch(body); m != nil {
			column.Default = strings.Trim(strings.TrimSpace(m[1]), "'\"")
		}
	}
}
// resolveRelationships resolves TypeORM relationships: @ManyToOne becomes a
// FK column + constraint on the owning table, and @ManyToMany combined with
// @JoinTable produces a join table.
func (r *Reader) resolveRelationships(entities []entityInfo, tableMap map[string]*models.Table, schema *models.Schema) {
	// lookupTable resolves an entity class name to its table. tableMap is
	// keyed by table name, which differs from the class name when
	// @Entity({ name }) renames the table; fall back to the class_name
	// metadata recorded by entityToTable in that case (previously such
	// entities were silently skipped here).
	lookupTable := func(className string) *models.Table {
		if t, ok := tableMap[className]; ok {
			return t
		}
		for _, t := range tableMap {
			if t.Metadata != nil && t.Metadata["class_name"] == className {
				return t
			}
		}
		return nil
	}

	// Track M2M relations that need join tables.
	type m2mRelation struct {
		ownerEntity  string
		targetEntity string
		ownerField   string
	}
	m2mRelations := make([]m2mRelation, 0)

	for _, entity := range entities {
		table := lookupTable(entity.name)
		if table == nil {
			continue
		}
		for _, field := range entity.fields {
			// @ManyToOne: add the FK column and constraint on the owning side.
			if r.hasDecorator(field, "@ManyToOne") {
				r.createManyToOneConstraint(field, entity.name, table, tableMap)
			}
			// @ManyToMany with @JoinTable marks the owning side of an M2M.
			if r.hasDecorator(field, "@ManyToMany") && r.hasDecorator(field, "@JoinTable") {
				if target := r.extractRelationTarget(field); target != "" {
					m2mRelations = append(m2mRelations, m2mRelation{
						ownerEntity:  entity.name,
						targetEntity: target,
						ownerField:   field.name,
					})
				}
			}
		}
	}
	// Create join tables for M2M relations.
	for _, rel := range m2mRelations {
		r.createManyToManyJoinTable(rel.ownerEntity, rel.targetEntity, tableMap, schema)
	}
}
// hasDecorator reports whether the field carries a decorator starting with
// the given name (e.g. "@ManyToOne").
func (r *Reader) hasDecorator(field fieldInfo, decoratorName string) bool {
	for idx := range field.decorators {
		if strings.HasPrefix(field.decorators[idx], decoratorName) {
			return true
		}
	}
	return false
}
// extractRelationTarget derives the target entity class from the field's
// declared type (e.g. "Photo[]" -> "Photo").
func (r *Reader) extractRelationTarget(field fieldInfo) string {
	return strings.TrimSpace(strings.TrimSuffix(field.typeName, "[]"))
}
// createManyToOneConstraint creates the FK column and constraint for a
// @ManyToOne field on the owning table.
func (r *Reader) createManyToOneConstraint(field fieldInfo, entityName string, table *models.Table, tableMap map[string]*models.Table) {
	targetEntity := r.extractRelationTarget(field)
	if targetEntity == "" {
		return
	}
	// Resolve the referenced table; fall back to class_name metadata when
	// @Entity({ name }) renamed the table away from the class name.
	targetTable := tableMap[targetEntity]
	if targetTable == nil {
		for _, t := range tableMap {
			if t.Metadata != nil && t.Metadata["class_name"] == targetEntity {
				targetTable = t
				break
			}
		}
	}
	if targetTable == nil {
		return
	}
	targetPK := r.getPrimaryKeyColumn(targetTable)
	if targetPK == nil {
		return
	}

	// FK column follows TypeORM's default join-column naming: <field>Id.
	fkColumnName := fmt.Sprintf("%sId", field.name)
	fkColumn := models.InitColumn(fkColumnName, table.Name, table.Schema)
	fkColumn.Type = targetPK.Type
	// nullable: true on the decorator makes the FK column nullable.
	isNullable := false
	for _, decorator := range field.decorators {
		if strings.Contains(decorator, "nullable: true") || strings.Contains(decorator, "nullable:true") {
			isNullable = true
			break
		}
	}
	fkColumn.NotNull = !isNullable
	table.Columns[fkColumnName] = fkColumn

	constraint := models.InitConstraint(
		fmt.Sprintf("fk_%s_%s", entityName, field.name),
		models.ForeignKeyConstraint,
	)
	constraint.Schema = table.Schema
	constraint.Table = table.Name
	constraint.Columns = []string{fkColumnName}
	// Reference the resolved table's actual schema and name; the original
	// used the class name and a hard-coded "public", which breaks when the
	// target entity customizes either via @Entity options.
	constraint.ReferencedSchema = targetTable.Schema
	constraint.ReferencedTable = targetTable.Name
	constraint.ReferencedColumns = []string{targetPK.Name}
	constraint.OnDelete = "CASCADE"
	table.Constraints[constraint.Name] = constraint
}
// createManyToManyJoinTable creates a join table (composite PK over two FK
// columns) for an M2M relation between entity1 (owning side) and entity2.
func (r *Reader) createManyToManyJoinTable(entity1, entity2 string, tableMap map[string]*models.Table, schema *models.Schema) {
	// Simplified naming: entity1_entity2 (TypeORM also appends the field name).
	joinTableName := fmt.Sprintf("%s_%s", strings.ToLower(entity1), strings.ToLower(entity2))
	// Nothing to do if the join table was already created.
	if _, exists := tableMap[joinTableName]; exists {
		return
	}
	// Both sides must exist and expose a primary key.
	table1 := tableMap[entity1]
	table2 := tableMap[entity2]
	if table1 == nil || table2 == nil {
		return
	}
	pk1 := r.getPrimaryKeyColumn(table1)
	pk2 := r.getPrimaryKeyColumn(table2)
	if pk1 == nil || pk2 == nil {
		return
	}

	joinTable := models.InitTable(joinTableName, "public")

	// Column/constraint names. A self-referencing M2M (entity1 == entity2)
	// would otherwise generate identical names that overwrite each other in
	// the column and constraint maps, so disambiguate with _1/_2 suffixes.
	fkCol1Name := fmt.Sprintf("%sId", strings.ToLower(entity1))
	fkCol2Name := fmt.Sprintf("%sId", strings.ToLower(entity2))
	fk1Name := fmt.Sprintf("fk_%s_%s", joinTableName, entity1)
	fk2Name := fmt.Sprintf("fk_%s_%s", joinTableName, entity2)
	if entity1 == entity2 {
		fkCol1Name += "_1"
		fkCol2Name += "_2"
		fk1Name += "_1"
		fk2Name += "_2"
	}

	fkCol1 := models.InitColumn(fkCol1Name, joinTableName, "public")
	fkCol1.Type = pk1.Type
	fkCol1.NotNull = true
	fkCol1.IsPrimaryKey = true
	joinTable.Columns[fkCol1Name] = fkCol1

	fkCol2 := models.InitColumn(fkCol2Name, joinTableName, "public")
	fkCol2.Type = pk2.Type
	fkCol2.NotNull = true
	fkCol2.IsPrimaryKey = true
	joinTable.Columns[fkCol2Name] = fkCol2

	// Composite primary key over both FK columns.
	pkConstraint := models.InitConstraint(
		fmt.Sprintf("pk_%s", joinTableName),
		models.PrimaryKeyConstraint,
	)
	pkConstraint.Schema = "public"
	pkConstraint.Table = joinTableName
	pkConstraint.Columns = []string{fkCol1Name, fkCol2Name}
	joinTable.Constraints[pkConstraint.Name] = pkConstraint

	// FK constraints back to each side of the relation.
	fk1 := models.InitConstraint(fk1Name, models.ForeignKeyConstraint)
	fk1.Schema = "public"
	fk1.Table = joinTableName
	fk1.Columns = []string{fkCol1Name}
	fk1.ReferencedSchema = "public"
	fk1.ReferencedTable = entity1
	fk1.ReferencedColumns = []string{pk1.Name}
	fk1.OnDelete = "CASCADE"
	joinTable.Constraints[fk1.Name] = fk1

	fk2 := models.InitConstraint(fk2Name, models.ForeignKeyConstraint)
	fk2.Schema = "public"
	fk2.Table = joinTableName
	fk2.Columns = []string{fkCol2Name}
	fk2.ReferencedSchema = "public"
	fk2.ReferencedTable = entity2
	fk2.ReferencedColumns = []string{pk2.Name}
	fk2.OnDelete = "CASCADE"
	joinTable.Constraints[fk2.Name] = fk2

	// Register the join table with the schema and the lookup map.
	schema.Tables = append(schema.Tables, joinTable)
	tableMap[joinTableName] = joinTable
}
// getPrimaryKeyColumn returns the primary key column of a table.
// Returns nil when table is nil or when no column is flagged IsPrimaryKey.
// NOTE(review): table.Columns is a map (columns are assigned by name), so
// for a composite primary key the column returned depends on Go's map
// iteration order — nondeterministic across runs.
func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column {
	if table == nil {
		return nil
	}
	for _, col := range table.Columns {
		if col.IsPrimaryKey {
			return col
		}
	}
	return nil
}

View File

@@ -0,0 +1,551 @@
package prisma
import (
"fmt"
"os"
"sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for Prisma schema format
type Writer struct {
	// options carries writer configuration; OutputPath selects file output vs stdout.
	options *writers.WriterOptions
}
// NewWriter constructs a Prisma writer bound to the supplied options.
func NewWriter(options *writers.WriterOptions) *Writer {
	w := &Writer{}
	w.options = options
	return w
}
// WriteDatabase renders a Database model as a Prisma schema, writing to
// options.OutputPath when set and to stdout otherwise.
func (w *Writer) WriteDatabase(db *models.Database) error {
	rendered := w.databaseToPrisma(db)
	if path := w.options.OutputPath; path != "" {
		return os.WriteFile(path, []byte(rendered), 0644)
	}
	fmt.Print(rendered)
	return nil
}
// WriteSchema renders a single Schema by wrapping it in a temporary database.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	wrapper := models.InitDatabase("database")
	wrapper.Schemas = []*models.Schema{schema}
	return w.WriteDatabase(wrapper)
}
// WriteTable renders a single Table by wrapping it in a temporary schema.
func (w *Writer) WriteTable(table *models.Table) error {
	wrapper := models.InitSchema(table.Schema)
	wrapper.Tables = []*models.Table{table}
	return w.WriteSchema(wrapper)
}
// databaseToPrisma renders a Database as a Prisma schema file: datasource
// and generator blocks, then per schema the enums and one model per table.
// Tables recognized as implicit M2M join tables are skipped — Prisma models
// them as array relation fields instead.
func (w *Writer) databaseToPrisma(db *models.Database) string {
	var out strings.Builder
	out.WriteString(w.generateDatasource(db))
	out.WriteString("\n")
	out.WriteString(w.generateGenerator())
	out.WriteString("\n")
	for _, schema := range db.Schemas {
		for _, enum := range schema.Enums {
			out.WriteString(w.enumToPrisma(enum))
			out.WriteString("\n")
		}
		joinTables := w.identifyJoinTables(schema)
		for _, table := range schema.Tables {
			if joinTables[table.Name] {
				continue
			}
			out.WriteString(w.tableToPrisma(table, schema, joinTables))
			out.WriteString("\n")
		}
	}
	return out.String()
}
// generateDatasource generates the Prisma datasource block, mapping the
// model's DatabaseType to a Prisma provider string (default: postgresql).
func (w *Writer) generateDatasource(db *models.Database) string {
	provider := "postgresql"
	// Map database type to Prisma provider
	switch db.DatabaseType {
	case models.PostgresqlDatabaseType:
		provider = "postgresql"
	case models.MSSQLDatabaseType:
		provider = "sqlserver"
	case models.SqlLiteDatabaseType:
		provider = "sqlite"
	// NOTE(review): bare string literal, unlike the models constants above —
	// confirm whether a models.MySQLDatabaseType constant exists.
	case "mysql":
		provider = "mysql"
	}
	return fmt.Sprintf(`datasource db {
provider = "%s"
url = env("DATABASE_URL")
}
`, provider)
}
// generateGenerator generates the Prisma generator block, fixed to the
// standard prisma-client-js client.
func (w *Writer) generateGenerator() string {
	return `generator client {
provider = "prisma-client-js"
}
`
}
// enumToPrisma renders an Enum as a Prisma enum block, one value per line.
func (w *Writer) enumToPrisma(enum *models.Enum) string {
	var out strings.Builder
	fmt.Fprintf(&out, "enum %s {\n", enum.Name)
	for _, value := range enum.Values {
		fmt.Fprintf(&out, " %s\n", value)
	}
	out.WriteString("}\n")
	return out.String()
}
// identifyJoinTables flags tables that represent implicit M2M join tables:
// name starts with "_" (Prisma convention), exactly two columns, exactly
// two FK constraints, and both columns in the primary key.
func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
	joins := make(map[string]bool)
	for _, table := range schema.Tables {
		if !strings.HasPrefix(table.Name, "_") || len(table.Columns) != 2 {
			continue
		}
		if len(table.GetForeignKeys()) != 2 {
			continue
		}
		// Both FK columns must participate in the primary key.
		pkCount := 0
		for _, col := range table.Columns {
			if col.IsPrimaryKey {
				pkCount++
			}
		}
		if pkCount == 2 {
			joins[table.Name] = true
		}
	}
	return joins
}
// tableToPrisma renders a table as a Prisma model block: name-sorted scalar
// fields, then relation fields, then block attributes (@@id/@@unique/@@index).
func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
	var out strings.Builder
	out.WriteString(fmt.Sprintf("model %s {\n", table.Name))

	ordered := make([]*models.Column, 0, len(table.Columns))
	for _, col := range table.Columns {
		ordered = append(ordered, col)
	}
	// Sort by name so output is stable despite map iteration order.
	sort.Slice(ordered, func(a, b int) bool {
		return ordered[a].Name < ordered[b].Name
	})

	for _, col := range ordered {
		// FK columns are emitted alongside their relation field instead.
		if w.isRelationColumn(col, table) {
			continue
		}
		out.WriteString(w.columnToField(col, table, schema))
	}
	out.WriteString(w.generateRelationFields(table, schema, joinTables))
	out.WriteString(w.generateBlockAttributes(table))
	out.WriteString("}\n")
	return out.String()
}
// columnToField renders one scalar field line: name, Prisma type, an
// optional "?" marker, then inline attributes (@id/@unique/@default/@updatedAt).
func (w *Writer) columnToField(col *models.Column, table *models.Table, schema *models.Schema) string {
	var out strings.Builder
	out.WriteString(fmt.Sprintf(" %s", col.Name))
	out.WriteString(fmt.Sprintf(" %s", w.sqlTypeToPrisma(col.Type, schema)))
	// Nullable, non-PK columns become optional fields.
	if !col.NotNull && !col.IsPrimaryKey {
		out.WriteString("?")
	}
	if attrs := w.generateFieldAttributes(col, table); attrs != "" {
		out.WriteString(" ")
		out.WriteString(attrs)
	}
	out.WriteString("\n")
	return out.String()
}
// sqlTypeToPrisma converts a SQL type name to a Prisma scalar type.
// Enum types are matched (case-insensitively) against the schema's enums
// first. SQL patterns are then checked in a fixed, most-specific-first
// order: the original map iteration was nondeterministic, so e.g. "bigint"
// could match the "int" pattern and come out as Int instead of BigInt.
func (w *Writer) sqlTypeToPrisma(sqlType string, schema *models.Schema) string {
	for _, enum := range schema.Enums {
		if strings.EqualFold(sqlType, enum.Name) {
			return enum.Name
		}
	}
	lower := strings.ToLower(sqlType)
	// Longer / more specific patterns first so substrings don't shadow them.
	patterns := []struct {
		sql    string
		prisma string
	}{
		{"character varying", "String"},
		{"varchar", "String"},
		{"text", "String"},
		{"char", "String"},
		{"boolean", "Boolean"},
		{"bool", "Boolean"},
		{"bigint", "BigInt"},
		{"int8", "BigInt"},
		{"integer", "Int"},
		{"int4", "Int"},
		{"int", "Int"},
		{"double precision", "Float"},
		{"float8", "Float"},
		{"float", "Float"},
		{"decimal", "Decimal"},
		{"numeric", "Decimal"},
		{"timestamptz", "DateTime"},
		{"timestamp", "DateTime"},
		{"date", "DateTime"},
		{"jsonb", "Json"},
		{"json", "Json"},
		{"bytea", "Bytes"},
	}
	for _, p := range patterns {
		if strings.Contains(lower, p.sql) {
			return p.prisma
		}
	}
	// Unknown types fall back to String.
	return "String"
}
// generateFieldAttributes renders inline field attributes: @id (single-column
// PK only — composite PKs become block-level @@id), @unique, @default, and
// @updatedAt.
func (w *Writer) generateFieldAttributes(col *models.Column, table *models.Table) string {
	attrs := make([]string, 0)
	if col.IsPrimaryKey {
		// Only emit @id for a single-column primary key.
		pkCount := 0
		for _, c := range table.Columns {
			if c.IsPrimaryKey {
				pkCount++
			}
		}
		if pkCount == 1 {
			attrs = append(attrs, "@id")
		}
	}
	if w.hasUniqueConstraint(col.Name, table) {
		attrs = append(attrs, "@unique")
	}
	if col.AutoIncrement {
		attrs = append(attrs, "@default(autoincrement())")
	} else if col.Default != nil {
		if d := w.formatDefaultValue(col.Default); d != "" {
			attrs = append(attrs, fmt.Sprintf("@default(%s)", d))
		}
	}
	// "auto-update" is the marker the TypeORM reader writes for
	// @UpdateDateColumn; "@updatedAt" covers sources using the Prisma name.
	// The original only checked "@updatedAt", so TypeORM round-trips lost
	// the attribute.
	if strings.Contains(col.Comment, "@updatedAt") || strings.Contains(col.Comment, "auto-update") {
		attrs = append(attrs, "@updatedAt")
	}
	return strings.Join(attrs, " ")
}
// formatDefaultValue formats a column default for a Prisma @default(...).
// Known SQL function defaults map to Prisma functions; other strings become
// quoted literals with backslashes and double quotes escaped — the original
// emitted them raw, which produced an unparseable schema for values
// containing a quote.
func (w *Writer) formatDefaultValue(defaultValue any) string {
	switch v := defaultValue.(type) {
	case string:
		switch {
		case v == "now()":
			return "now()"
		case v == "gen_random_uuid()", strings.Contains(strings.ToLower(v), "uuid"):
			return "uuid()"
		default:
			// String literal: escape for Prisma's double-quoted strings.
			escaped := strings.ReplaceAll(v, `\`, `\\`)
			escaped = strings.ReplaceAll(escaped, `"`, `\"`)
			return fmt.Sprintf(`"%s"`, escaped)
		}
	case bool:
		if v {
			return "true"
		}
		return "false"
	default:
		// Numeric and any other types render via their %v form.
		return fmt.Sprintf("%v", v)
	}
}
// hasUniqueConstraint reports whether the table holds a single-column unique
// constraint covering exactly colName.
func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool {
	for _, c := range table.Constraints {
		if c.Type != models.UniqueConstraint {
			continue
		}
		if len(c.Columns) == 1 && c.Columns[0] == colName {
			return true
		}
	}
	return false
}
// isRelationColumn reports whether the column participates in any FK
// constraint on the table.
func (w *Writer) isRelationColumn(col *models.Column, table *models.Table) bool {
	for _, c := range table.Constraints {
		if c.Type != models.ForeignKeyConstraint {
			continue
		}
		for _, name := range c.Columns {
			if name == col.Name {
				return true
			}
		}
	}
	return false
}
// generateRelationFields renders, for each FK constraint on the table, the
// FK scalar field(s) followed by the relation field with its @relation
// attribute, then appends inverse (array) relations for tables that
// reference this one.
func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
	var sb strings.Builder
	// Get all FK constraints
	fks := table.GetForeignKeys()
	for _, fk := range fks {
		// Emit the FK scalar field(s) next to their relation field.
		for _, fkCol := range fk.Columns {
			if col, exists := table.Columns[fkCol]; exists {
				sb.WriteString(w.columnToField(col, table, schema))
			}
		}
		// The relation field is optional when any FK column is nullable.
		relationType := fk.ReferencedTable
		isOptional := false
		for _, fkCol := range fk.Columns {
			if col, exists := table.Columns[fkCol]; exists {
				if !col.NotNull {
					isOptional = true
				}
			}
		}
		// Naive singularization: strip one trailing "s" for the field name.
		// NOTE(review): this can collide with an existing scalar field name
		// and mishandles irregular plurals — verify against generated schemas.
		relationName := relationType
		if strings.HasSuffix(strings.ToLower(relationName), "s") {
			relationName = relationName[:len(relationName)-1]
		}
		sb.WriteString(fmt.Sprintf(" %s %s", strings.ToLower(relationName), relationType))
		if isOptional {
			sb.WriteString("?")
		}
		// @relation(fields: ..., references: ..., ...) attribute.
		relationAttr := w.generateRelationAttribute(fk)
		if relationAttr != "" {
			sb.WriteString(" ")
			sb.WriteString(relationAttr)
		}
		sb.WriteString("\n")
	}
	// Inverse (array) relations for tables that reference this one.
	sb.WriteString(w.generateInverseRelations(table, schema, joinTables))
	return sb.String()
}
// generateRelationAttribute renders the @relation(fields: ..., references:
// ..., onDelete/onUpdate: ...) attribute for a FK constraint.
func (w *Writer) generateRelationAttribute(fk *models.Constraint) string {
	parts := make([]string, 0)
	parts = append(parts, fmt.Sprintf("fields: [%s]", strings.Join(fk.Columns, ", ")))
	parts = append(parts, fmt.Sprintf("references: [%s]", strings.Join(fk.ReferencedColumns, ", ")))
	// Referential actions must use Prisma's PascalCase names; SQL-style
	// values like "CASCADE" or "SET NULL" (as produced by the readers) are
	// not valid Prisma syntax, so normalize them.
	if fk.OnDelete != "" {
		parts = append(parts, fmt.Sprintf("onDelete: %s", prismaReferentialAction(fk.OnDelete)))
	}
	if fk.OnUpdate != "" {
		parts = append(parts, fmt.Sprintf("onUpdate: %s", prismaReferentialAction(fk.OnUpdate)))
	}
	return fmt.Sprintf("@relation(%s)", strings.Join(parts, ", "))
}

// prismaReferentialAction maps SQL-style referential actions (e.g. "CASCADE",
// "SET NULL") to Prisma's PascalCase names (Cascade, SetNull). Values already
// in Prisma form, and unrecognized values, pass through unchanged.
func prismaReferentialAction(action string) string {
	switch strings.ToUpper(strings.TrimSpace(action)) {
	case "CASCADE":
		return "Cascade"
	case "RESTRICT":
		return "Restrict"
	case "NO ACTION", "NOACTION":
		return "NoAction"
	case "SET NULL", "SETNULL":
		return "SetNull"
	case "SET DEFAULT", "SETDEFAULT":
		return "SetDefault"
	default:
		return action
	}
}
// generateInverseRelations renders array fields for reverse relationships:
// for join tables an implicit M2M array of the other side, for regular
// tables a pluralized array field per referencing table.
func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
	var sb strings.Builder
	// Scan every other table for FKs pointing back at this one.
	for _, otherTable := range schema.Tables {
		if otherTable.Name == table.Name {
			continue
		}
		// Join tables collapse into implicit M2M array fields.
		if joinTables[otherTable.Name] {
			if w.isJoinTableFor(otherTable, table.Name) {
				// Find the FK pointing at the other side of the M2M.
				for _, fk := range otherTable.GetForeignKeys() {
					if fk.ReferencedTable != table.Name {
						// NOTE(review): naive pluralization (append "s") —
						// mishandles irregular plurals.
						otherSide := fk.ReferencedTable
						sb.WriteString(fmt.Sprintf(" %ss %s[]\n",
							strings.ToLower(otherSide), otherSide))
						break
					}
				}
			}
			continue
		}
		// Regular one-to-many inverse relation.
		for _, fk := range otherTable.GetForeignKeys() {
			if fk.ReferencedTable == table.Name {
				// Pluralize the referencing table's name unless it already ends in "s".
				pluralName := otherTable.Name
				if !strings.HasSuffix(pluralName, "s") {
					pluralName += "s"
				}
				sb.WriteString(fmt.Sprintf(" %s %s[]\n",
					strings.ToLower(pluralName), otherTable.Name))
			}
		}
	}
	return sb.String()
}
// isJoinTableFor reports whether joinTable carries a FK referencing modelName.
func (w *Writer) isJoinTableFor(joinTable *models.Table, modelName string) bool {
	fks := joinTable.GetForeignKeys()
	for idx := range fks {
		if fks[idx].ReferencedTable == modelName {
			return true
		}
	}
	return false
}
// generateBlockAttributes renders block-level attributes: @@id for composite
// primary keys, @@unique for multi-column unique constraints, and @@index
// for non-unique indexes. Constraints are visited in sorted name order so
// output is deterministic — table.Constraints is a map, and the original
// iterated it directly, producing randomly ordered @@unique lines while the
// rest of the writer deliberately sorts its output.
func (w *Writer) generateBlockAttributes(table *models.Table) string {
	var sb strings.Builder
	// @@id for composite primary key.
	pkCols := make([]string, 0)
	for _, col := range table.Columns {
		if col.IsPrimaryKey {
			pkCols = append(pkCols, col.Name)
		}
	}
	if len(pkCols) > 1 {
		sort.Strings(pkCols)
		sb.WriteString(fmt.Sprintf(" @@id([%s])\n", strings.Join(pkCols, ", ")))
	}
	// @@unique for multi-column unique constraints, in stable name order.
	names := make([]string, 0, len(table.Constraints))
	for name := range table.Constraints {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		constraint := table.Constraints[name]
		if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
			sb.WriteString(fmt.Sprintf(" @@unique([%s])\n", strings.Join(constraint.Columns, ", ")))
		}
	}
	// @@index for non-unique indexes (unique ones are handled by @@unique).
	for _, index := range table.Indexes {
		if !index.Unique {
			sb.WriteString(fmt.Sprintf(" @@index([%s])\n", strings.Join(index.Columns, ", ")))
		}
	}
	return sb.String()
}

View File

@@ -0,0 +1,631 @@
package typeorm
import (
	"fmt"
	"os"
	"sort"
	"strconv"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for TypeORM entity format.
// It renders a models.Database as TypeScript entity classes decorated with
// TypeORM decorators (@Entity, @Column, relation decorators, @ViewEntity).
type Writer struct {
	// options controls the output destination (OutputPath; stdout when empty).
	options *writers.WriterOptions
}
// NewWriter creates a new TypeORM writer configured with the given options.
func NewWriter(options *writers.WriterOptions) *Writer {
	w := &Writer{options: options}
	return w
}
// WriteDatabase renders db as TypeORM entity source, writing the result to
// the configured output path, or to stdout when no path is set.
func (w *Writer) WriteDatabase(db *models.Database) error {
	content := w.databaseToTypeORM(db)
	if path := w.options.OutputPath; path != "" {
		return os.WriteFile(path, []byte(content), 0644)
	}
	fmt.Print(content)
	return nil
}
// WriteSchema wraps schema in a synthetic single-schema database and
// delegates to WriteDatabase.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	wrapper := models.InitDatabase("database")
	wrapper.Schemas = []*models.Schema{schema}
	return w.WriteDatabase(wrapper)
}
// WriteTable wraps table in a synthetic schema (named after the table's own
// schema) and delegates to WriteSchema.
func (w *Writer) WriteTable(table *models.Table) error {
	wrapper := models.InitSchema(table.Schema)
	wrapper.Tables = []*models.Table{table}
	return w.WriteSchema(wrapper)
}
// databaseToTypeORM renders an entire Database as one TypeORM source string:
// a single import line, then an entity class per table (join tables are
// folded into @ManyToMany relations rather than emitted), then one class
// per view.
func (w *Writer) databaseToTypeORM(db *models.Database) string {
	var out strings.Builder
	out.WriteString(w.generateImports(db))
	out.WriteString("\n")
	for _, schema := range db.Schemas {
		// Join tables are represented as relations, never as entities.
		joinTables := w.identifyJoinTables(schema)
		for _, table := range schema.Tables {
			if joinTables[table.Name] {
				continue
			}
			out.WriteString(w.tableToEntity(table, schema, joinTables))
			out.WriteString("\n")
		}
		for _, view := range schema.Views {
			out.WriteString(w.viewToEntity(view))
			out.WriteString("\n")
		}
	}
	return out.String()
}
// generateImports builds the single typeorm import statement, pulling in only
// the decorators the rendered schemas actually need. The base trio (Entity,
// PrimaryGeneratedColumn, Column) is always present; relation, timestamp and
// view decorators are appended on demand, in a fixed order.
func (w *Writer) generateImports(db *models.Database) string {
	needsManyToOne := false
	needsManyToMany := false
	needsCreateDate := false
	needsUpdateDate := false
	needsViewEntity := false

	for _, schema := range db.Schemas {
		if len(schema.Views) > 0 {
			needsViewEntity = true
		}
		// Any join table implies a ManyToMany pair with a @JoinTable owner.
		// (The original built a name set here that was never read; only the
		// boolean is needed.)
		if len(w.identifyJoinTables(schema)) > 0 {
			needsManyToMany = true
		}
		for _, table := range schema.Tables {
			for _, col := range table.Columns {
				// Same heuristics columnToField uses for timestamp columns.
				if col.Default == "now()" {
					needsCreateDate = true
				}
				if strings.Contains(col.Comment, "auto-update") {
					needsUpdateDate = true
				}
			}
			for _, constraint := range table.Constraints {
				if constraint.Type == models.ForeignKeyConstraint {
					needsManyToOne = true
				}
			}
		}
	}

	imports := []string{"Entity", "PrimaryGeneratedColumn", "Column"}
	if needsManyToOne {
		// OneToMany is always generated as the inverse of ManyToOne.
		imports = append(imports, "ManyToOne", "OneToMany")
	}
	if needsManyToMany {
		imports = append(imports, "ManyToMany", "JoinTable")
	}
	if needsCreateDate {
		imports = append(imports, "CreateDateColumn")
	}
	if needsUpdateDate {
		imports = append(imports, "UpdateDateColumn")
	}
	if needsViewEntity {
		imports = append(imports, "ViewEntity")
	}
	return fmt.Sprintf("import { %s } from 'typeorm';\n", strings.Join(imports, ", "))
}
// identifyJoinTables returns the set of table names that look like pure join
// tables for implicit many-to-many relations: exactly two columns, both
// foreign keys, both part of a composite primary key.
func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
	result := make(map[string]bool)
	for _, table := range schema.Tables {
		// Must have exactly 2 FKs and no columns besides the FK columns.
		if len(table.GetForeignKeys()) != 2 || len(table.Columns) != 2 {
			continue
		}
		// Both columns must participate in the primary key.
		pkCount := 0
		for _, col := range table.Columns {
			if col.IsPrimaryKey {
				pkCount++
			}
		}
		if pkCount == 2 {
			result[table.Name] = true
		}
	}
	return result
}
// tableToEntity renders one table as a decorated TypeORM entity class:
// @Entity options, scalar fields (primary key first, then alphabetical),
// then relation fields. FK columns are omitted — they surface as relations.
func (w *Writer) tableToEntity(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
	var out strings.Builder

	out.WriteString(fmt.Sprintf("@Entity({\n%s\n})\n", w.buildEntityOptions(table)))

	// The class name defaults to the table name but may be overridden via the
	// "class_name" metadata entry.
	className := table.Name
	if table.Metadata != nil {
		if v, ok := table.Metadata["class_name"]; ok {
			if s, ok := v.(string); ok {
				className = s
			}
		}
	}
	out.WriteString(fmt.Sprintf("export class %s {\n", className))

	// Collect non-FK columns and order them PK-first, then by name, so output
	// is stable regardless of iteration order over table.Columns.
	scalars := make([]*models.Column, 0, len(table.Columns))
	for _, col := range table.Columns {
		if !w.isForeignKeyColumn(col, table) {
			scalars = append(scalars, col)
		}
	}
	sort.Slice(scalars, func(i, j int) bool {
		if scalars[i].IsPrimaryKey != scalars[j].IsPrimaryKey {
			return scalars[i].IsPrimaryKey
		}
		return scalars[i].Name < scalars[j].Name
	})

	for _, col := range scalars {
		out.WriteString(w.columnToField(col, table))
		out.WriteString("\n")
	}

	out.WriteString(w.generateRelationFields(table, schema, joinTables))
	out.WriteString("}\n")
	return out.String()
}
// viewToEntity renders a database view as a @ViewEntity class whose fields
// carry no column decorators.
func (w *Writer) viewToEntity(view *models.View) string {
	var out strings.Builder

	out.WriteString("@ViewEntity({\n")
	if view.Definition != "" {
		// Embed the defining SQL in a template literal.
		out.WriteString(" expression: `\n")
		out.WriteString(" ")
		out.WriteString(view.Definition)
		out.WriteString("\n `,\n")
	}
	out.WriteString("})\n")

	out.WriteString(fmt.Sprintf("export class %s {\n", view.Name))

	// Fields are sorted by name for stable output.
	cols := make([]*models.Column, 0, len(view.Columns))
	for _, c := range view.Columns {
		cols = append(cols, c)
	}
	sort.Slice(cols, func(i, j int) bool { return cols[i].Name < cols[j].Name })

	for _, c := range cols {
		out.WriteString(fmt.Sprintf(" %s: %s;\n", c.Name, w.sqlTypeToTypeScript(c.Type)))
	}
	out.WriteString("}\n")
	return out.String()
}
// columnToField renders one scalar column: its decorator line followed by the
// TypeScript field declaration (nullable columns get a `| null` union type).
func (w *Writer) columnToField(col *models.Column, table *models.Table) string {
	var out strings.Builder

	switch {
	case col.IsPrimaryKey:
		// Pick the generation strategy from the column shape.
		switch {
		case col.AutoIncrement:
			out.WriteString(" @PrimaryGeneratedColumn('increment')\n")
		case col.Type == "uuid" || strings.Contains(fmt.Sprint(col.Default), "uuid"):
			out.WriteString(" @PrimaryGeneratedColumn('uuid')\n")
		default:
			out.WriteString(" @PrimaryGeneratedColumn()\n")
		}
	case col.Default == "now()":
		// now() default is treated as a creation timestamp.
		out.WriteString(" @CreateDateColumn()\n")
	case strings.Contains(col.Comment, "auto-update"):
		// "auto-update" in the comment marks an update timestamp.
		out.WriteString(" @UpdateDateColumn()\n")
	default:
		if options := w.buildColumnOptions(col, table); options != "" {
			out.WriteString(fmt.Sprintf(" @Column({ %s })\n", options))
		} else {
			out.WriteString(" @Column()\n")
		}
	}

	suffix := ""
	if !col.NotNull {
		suffix = " | null"
	}
	out.WriteString(fmt.Sprintf(" %s: %s%s;", col.Name, w.sqlTypeToTypeScript(col.Type), suffix))
	return out.String()
}
// buildColumnOptions builds the option list for a @Column decorator: explicit
// type (when the SQL type doesn't map cleanly to TS), nullable, unique, and
// default. Boolean and numeric defaults are emitted unquoted so TypeORM sees
// JS literals (`default: false`, `default: 0`) instead of the strings
// `'false'` / `'0'` the previous quoting produced.
func (w *Writer) buildColumnOptions(col *models.Column, table *models.Table) string {
	options := make([]string, 0)
	// Type (if not default)
	if w.needsExplicitType(col.Type) {
		options = append(options, fmt.Sprintf("type: '%s'", col.Type))
	}
	// Nullable
	if !col.NotNull {
		options = append(options, "nullable: true")
	}
	// Unique
	if w.hasUniqueConstraint(col.Name, table) {
		options = append(options, "unique: true")
	}
	// Default — now() is handled by @CreateDateColumn in columnToField.
	if col.Default != nil && col.Default != "now()" {
		if defaultStr := fmt.Sprint(col.Default); defaultStr != "" {
			if isLiteralDefault(defaultStr) {
				options = append(options, fmt.Sprintf("default: %s", defaultStr))
			} else {
				options = append(options, fmt.Sprintf("default: '%s'", defaultStr))
			}
		}
	}
	return strings.Join(options, ", ")
}

// isLiteralDefault reports whether s is a boolean or numeric literal that is
// valid unquoted in TypeScript source.
func isLiteralDefault(s string) bool {
	if s == "true" || s == "false" {
		return true
	}
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}
// needsExplicitType reports whether sqlType must be spelled out in the
// @Column options because it doesn't map one-to-one onto a TS primitive.
func (w *Writer) needsExplicitType(sqlType string) bool {
	for _, t := range [...]string{"text", "uuid", "jsonb", "bigint"} {
		if strings.Contains(sqlType, t) {
			return true
		}
	}
	return false
}
// hasUniqueConstraint reports whether colName carries a single-column unique
// constraint on table (multi-column uniques are ignored here).
func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool {
	for _, c := range table.Constraints {
		if c.Type != models.UniqueConstraint {
			continue
		}
		if len(c.Columns) == 1 && c.Columns[0] == colName {
			return true
		}
	}
	return false
}
// sqlTypeToTypeScript maps a SQL type name to the TypeScript type used in
// generated entity fields, falling back to "any" for unknown types.
//
// Patterns are held in an ordered slice rather than a map: matching is by
// substring, and Go randomizes map iteration order, so a map could pick a
// different (less specific) pattern from run to run. The list is ordered
// most-specific-first so e.g. "character varying" wins over "char".
func (w *Writer) sqlTypeToTypeScript(sqlType string) string {
	patterns := []struct {
		sql string
		ts  string
	}{
		{"character varying", "string"},
		{"double precision", "number"},
		{"timestamptz", "Date"},
		{"timestamp", "Date"},
		{"varchar", "string"},
		{"boolean", "boolean"},
		{"integer", "number"},
		{"numeric", "number"},
		{"decimal", "number"},
		{"bigint", "number"},
		{"jsonb", "any"},
		{"float", "number"},
		{"json", "any"},
		{"text", "string"},
		{"uuid", "string"},
		{"char", "string"},
		{"bool", "boolean"},
		{"date", "Date"},
		{"int", "number"},
	}
	lower := strings.ToLower(sqlType)
	for _, p := range patterns {
		if strings.Contains(lower, p.sql) {
			return p.ts
		}
	}
	return "any"
}
// isForeignKeyColumn reports whether col participates in any foreign-key
// constraint on table (such columns are rendered as relations, not fields).
func (w *Writer) isForeignKeyColumn(col *models.Column, table *models.Table) bool {
	for _, c := range table.Constraints {
		if c.Type != models.ForeignKeyConstraint {
			continue
		}
		for _, name := range c.Columns {
			if name == col.Name {
				return true
			}
		}
	}
	return false
}
// generateRelationFields renders every relation field for table, in order:
// a @ManyToOne field for each outgoing FK, then @OneToMany inverses for
// incoming FKs, then @ManyToMany fields derived from join tables.
func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
	var sb strings.Builder
	// Get all FK constraints
	fks := table.GetForeignKeys()
	// Generate @ManyToOne fields
	for _, fk := range fks {
		relatedTable := fk.ReferencedTable
		// NOTE(review): the field is named after the referenced table, so two
		// FKs to the same table would produce colliding field names — confirm
		// upstream schemas avoid that.
		fieldName := strings.ToLower(relatedTable)
		// The relation is nullable if any of its FK columns is nullable.
		isNullable := false
		for _, fkCol := range fk.Columns {
			if col, exists := table.Columns[fkCol]; exists {
				if !col.NotNull {
					isNullable = true
				}
			}
		}
		nullable := ""
		if isNullable {
			nullable = " | null"
		}
		// Find inverse field name if possible; with one, the decorator gets
		// the inverse-arrow form, otherwise a bare @ManyToOne (plus
		// { nullable: true } when applicable).
		inverseField := w.findInverseFieldName(table.Name, relatedTable, schema)
		if inverseField != "" {
			sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, %s => %s.%s)\n",
				relatedTable, strings.ToLower(relatedTable), strings.ToLower(relatedTable), inverseField))
		} else {
			if isNullable {
				sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, { nullable: true })\n", relatedTable))
			} else {
				sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s)\n", relatedTable))
			}
		}
		sb.WriteString(fmt.Sprintf(" %s: %s%s;\n", fieldName, relatedTable, nullable))
		sb.WriteString("\n")
	}
	// Generate @OneToMany fields (inverse of FKs pointing to this table)
	w.generateInverseRelations(table, schema, joinTables, &sb)
	// Generate @ManyToMany fields
	w.generateManyToManyRelations(table, schema, joinTables, &sb)
	return sb.String()
}
// findInverseFieldName returns the field name on toTable's entity class that
// forms the inverse side of a fromTable -> toTable ManyToOne relation, or ""
// when no such relation exists.
//
// generateInverseRelations emits, on the referenced table's class, a
// @OneToMany list named after the pluralized lowercased referencing table —
// so the inverse field exists exactly when fromTable holds an FK referencing
// toTable. (The previous check tested the opposite direction, toTable ->
// fromTable, so the inverse arrow was dropped for ordinary one-way FKs.)
func (w *Writer) findInverseFieldName(fromTable, toTable string, schema *models.Schema) string {
	for _, table := range schema.Tables {
		if table.Name != fromTable {
			continue
		}
		for _, constraint := range table.Constraints {
			if constraint.Type == models.ForeignKeyConstraint && constraint.ReferencedTable == toTable {
				// Matches the naming used by generateInverseRelations.
				return w.pluralize(strings.ToLower(fromTable))
			}
		}
	}
	return ""
}
// generateInverseRelations appends a @OneToMany field to sb for every
// non-join table holding a foreign key that references table.
func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool, sb *strings.Builder) {
	for _, other := range schema.Tables {
		if other.Name == table.Name || joinTables[other.Name] {
			continue
		}
		for _, fk := range other.GetForeignKeys() {
			if fk.ReferencedTable != table.Name {
				continue
			}
			// Field: pluralized child name; the inverse arrow points back at
			// the lowercased parent field generated on the child side.
			childVar := strings.ToLower(other.Name)
			sb.WriteString(fmt.Sprintf(" @OneToMany(() => %s, %s => %s.%s)\n",
				other.Name, childVar, childVar, strings.ToLower(table.Name)))
			sb.WriteString(fmt.Sprintf(" %s: %s[];\n", w.pluralize(childVar), other.Name))
			sb.WriteString("\n")
		}
	}
}
// generateManyToManyRelations appends a @ManyToMany field to sb for every
// join table that links table to another entity. The alphabetically smaller
// entity name owns the relation and carries @JoinTable. Join-table names are
// visited in sorted order so the generated source is deterministic — Go
// randomizes map iteration order, and the previous direct range over the map
// made output order vary between runs.
func (w *Writer) generateManyToManyRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool, sb *strings.Builder) {
	names := make([]string, 0, len(joinTables))
	for name := range joinTables {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, joinTableName := range names {
		joinTable := w.findTable(joinTableName, schema)
		if joinTable == nil {
			continue
		}
		fks := joinTable.GetForeignKeys()
		if len(fks) != 2 {
			continue
		}
		// Locate the FK pointing at this table; its sibling identifies the
		// partner entity.
		var thisTableFK *models.Constraint
		var otherTableFK *models.Constraint
		for i, fk := range fks {
			if fk.ReferencedTable == table.Name {
				thisTableFK = fk
				if i == 0 {
					otherTableFK = fks[1]
				} else {
					otherTableFK = fks[0]
				}
			}
		}
		if thisTableFK == nil {
			continue
		}
		// The alphabetically-first entity owns the relation (gets @JoinTable).
		// NOTE(review): a self-referential join table (both FKs to the same
		// entity) produces no owner here — confirm that case cannot occur
		// upstream.
		isOwner := table.Name < otherTableFK.ReferencedTable
		otherTable := otherTableFK.ReferencedTable
		fieldName := w.pluralize(strings.ToLower(otherTable))
		inverseName := w.pluralize(strings.ToLower(table.Name))
		sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n",
			otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName))
		if isOwner {
			sb.WriteString(" @JoinTable()\n")
		}
		sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable))
		sb.WriteString("\n")
	}
}
// findTable returns the table named name from schema, or nil when absent.
func (w *Writer) findTable(name string, schema *models.Schema) *models.Table {
	for _, t := range schema.Tables {
		if t.Name == name {
			return t
		}
	}
	return nil
}
// buildEntityOptions renders the @Entity option lines: table name and schema
// are always present; database and engine are added from table metadata when
// available.
func (w *Writer) buildEntityOptions(table *models.Table) string {
	options := []string{
		fmt.Sprintf(" name: \"%s\"", table.Name),
		fmt.Sprintf(" schema: \"%s\"", table.Schema),
	}
	if table.Metadata != nil {
		if v, ok := table.Metadata["database"]; ok {
			if s, ok := v.(string); ok {
				options = append(options, fmt.Sprintf(" database: \"%s\"", s))
			}
		}
		if v, ok := table.Metadata["engine"]; ok {
			if s, ok := v.(string); ok {
				options = append(options, fmt.Sprintf(" engine: \"%s\"", s))
			}
		}
	}
	return strings.Join(options, ",\n")
}
// pluralize naively pluralizes word by appending "s" unless it already ends
// in one ("user" -> "users", "status" -> "status"). Irregular plurals
// ("category" -> "categories") are not handled.
func (w *Writer) pluralize(word string) string {
	if strings.HasSuffix(word, "s") {
		return word
	}
	return word + "s"
}

View File

@@ -0,0 +1,46 @@
// Sample Prisma schema: scalars, an enum default, one-to-one (User/Profile),
// one-to-many (User/Post) and implicit many-to-many (Post/Category).
datasource db {
  // NOTE(review): `url` is omitted — presumably fine for a parsing fixture,
  // but the Prisma CLI itself requires one; confirm this file is never run.
  provider = "postgresql"
}
generator client {
  provider = "prisma-client"
  output = "./generated"
}
// User owns Posts (1-n) and at most one Profile (1-1).
model User {
  id Int @id @default(autoincrement())
  email String @unique
  name String?
  role Role @default(USER)
  posts Post[]
  profile Profile?
}
// Profile holds the FK side of the 1-1 relation; @unique on userId is what
// makes it one-to-one rather than one-to-many.
model Profile {
  id Int @id @default(autoincrement())
  bio String
  user User @relation(fields: [userId], references: [id])
  userId Int @unique
}
model Post {
  id Int @id @default(autoincrement())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt
  title String
  published Boolean @default(false)
  author User @relation(fields: [authorId], references: [id])
  authorId Int
  // Implicit m-n: Prisma manages the hidden join table.
  categories Category[]
}
model Category {
  id Int @id @default(autoincrement())
  name String
  posts Post[]
}
enum Role {
  USER
  ADMIN
}

View File

@@ -0,0 +1,115 @@
//@ts-nocheck
import { Entity, PrimaryGeneratedColumn, Column, ManyToOne, OneToMany, ManyToMany, JoinTable, CreateDateColumn, UpdateDateColumn } from 'typeorm';
// Fixture entity: user with UUID PK, unique email, audit timestamps, and
// both sides of its Project relations (1-n as owner, m-n as member — the
// member side carries @JoinTable, making User the owning side).
@Entity()
export class User {
  @PrimaryGeneratedColumn('uuid')
  id: string;
  @Column({ unique: true })
  email: string;
  @Column()
  name: string;
  @CreateDateColumn()
  createdAt: Date;
  @UpdateDateColumn()
  updatedAt: Date;
  @OneToMany(() => Project, project => project.owner)
  ownedProjects: Project[];
  @ManyToMany(() => Project, project => project.members)
  @JoinTable()
  projects: Project[];
}
// Fixture entity: project owned by one User, with m-n membership (inverse
// side) and a 1-n collection of Tasks.
@Entity()
export class Project {
  @PrimaryGeneratedColumn('uuid')
  id: string;
  @Column()
  title: string;
  @Column({ nullable: true })
  description: string;
  @Column({ default: 'active' })
  status: string;
  @ManyToOne(() => User, user => user.ownedProjects)
  owner: User;
  @ManyToMany(() => User, user => user.projects)
  members: User[];
  @OneToMany(() => Task, task => task.project)
  tasks: Task[];
  @CreateDateColumn()
  createdAt: Date;
}
// Fixture entity: task belonging to a Project, with an optional assignee
// (nullable ManyToOne without an inverse side) and a Comment collection.
@Entity()
export class Task {
  @PrimaryGeneratedColumn('uuid')
  id: string;
  @Column()
  title: string;
  @Column({ type: 'text', nullable: true })
  description: string;
  @Column({ default: 'todo' })
  status: string;
  @Column({ nullable: true })
  dueDate: Date;
  @ManyToOne(() => Project, project => project.tasks)
  project: Project;
  @ManyToOne(() => User, { nullable: true })
  assignee: User;
  @OneToMany(() => Comment, comment => comment.task)
  comments: Comment[];
}
// Fixture entity: comment on a Task; author is a unidirectional ManyToOne
// (no inverse collection on User).
@Entity()
export class Comment {
  @PrimaryGeneratedColumn('uuid')
  id: string;
  @Column('text')
  content: string;
  @ManyToOne(() => Task, task => task.comments)
  task: Task;
  @ManyToOne(() => User)
  author: User;
  @CreateDateColumn()
  createdAt: Date;
}
// Fixture entity: tag with a unidirectional, owning m-n relation to Task
// (no inverse arrow on the @ManyToMany, @JoinTable on this side).
@Entity()
export class Tag {
  @PrimaryGeneratedColumn('uuid')
  id: string;
  @Column({ unique: true })
  name: string;
  @Column()
  color: string;
  @ManyToMany(() => Task)
  @JoinTable()
  tasks: Task[];
}