diff --git a/TODO.md b/TODO.md index d8dc24f..9eff2d8 100644 --- a/TODO.md +++ b/TODO.md @@ -2,21 +2,21 @@ ## Input Readers / Writers -- [x] **Database Inspector** - - [x] PostgreSQL driver +- [✔️] **Database Inspector** + - [✔️] PostgreSQL driver - [ ] MySQL driver - [ ] SQLite driver - [ ] MSSQL driver - - [x] Foreign key detection - - [x] Index extraction + - [✔️] Foreign key detection + - [✔️] Index extraction - [ ] .sql file generation with sequence and priority -- [*] .dbml: Database Markup Language (DBML) for textual schema representation. -- [ ] Prisma schema support (PSL format) .prisma -- [ ] Entity Framework (.NET) model .edmx -- [ ] TypeORM support -- [ ] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) -- [ ] Django models.py (Python classes), Sequelize migrations (JS) -- [ ] .avsc: Avro schema (JSON format for data serialization) +- [✔️] .dbml: Database Markup Language (DBML) for textual schema representation. +- [✔️] Prisma schema support (PSL format) .prisma +- [☠️] Entity Framework (.NET) model .edmx (Fuck no, EDMX files were bloated, verbose XML nightmares—hard to merge, error-prone, and a pain in teams. Microsoft wisely ditched them in EF Core for code-first. Classic overkill from old MS era.) +- [✔️] TypeORM support +- [] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) (💲 Someone can do this, not me) +- [ ] Django models.py (Python classes), Sequelize migrations (JS) (💲 Someone can do this, not me) +- [] .avsc: Avro schema (JSON format for data serialization) (💲 Someone can do this, not me) diff --git a/cmd/relspec/convert.go b/cmd/relspec/convert.go index a755aaa..ce59b1a 100644 --- a/cmd/relspec/convert.go +++ b/cmd/relspec/convert.go @@ -15,6 +15,8 @@ import ( "git.warky.dev/wdevs/relspecgo/pkg/readers/gorm" "git.warky.dev/wdevs/relspecgo/pkg/readers/json" "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql" + "git.warky.dev/wdevs/relspecgo/pkg/readers/prisma" + "git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm" "git.warky.dev/wdevs/relspecgo/pkg/readers/yaml" "git.warky.dev/wdevs/relspecgo/pkg/writers" wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun" @@ -24,6 +26,8 @@ import ( wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm" wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json" wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql" + wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma" + wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm" wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml" "github.com/spf13/cobra" ) @@ -55,6 +59,8 @@ Input formats: - yaml: YAML database schema - gorm: GORM model files (Go, file or directory) - bun: Bun model files (Go, file or directory) + - prisma: Prisma schema files (.prisma) + - typeorm: TypeORM entity files (TypeScript) - pgsql: PostgreSQL database (live connection) Output formats: @@ -65,6 +71,8 @@ Output formats: - yaml: YAML database schema - gorm: GORM model files (Go) - bun: Bun model files (Go) + - prisma: Prisma schema files (.prisma) + - typeorm: TypeORM entity files (TypeScript) - pgsql: PostgreSQL SQL schema PostgreSQL Connection String Examples: @@ -123,11 +131,11 @@ Examples: } func init() { - convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, pgsql)") + convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, prisma, typeorm, pgsql)") convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)") convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for database formats)") - convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, json, yaml, gorm, bun, pgsql)") + convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, json, yaml, gorm, bun, prisma, typeorm, pgsql)") convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)") convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)") convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)") @@ -239,6 +247,18 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa } reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + case "prisma": + if filePath == "" { + return nil, fmt.Errorf("file path is required for Prisma format") + } + reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) + + case "typeorm": + if filePath == "" { + return nil, fmt.Errorf("file path is required for TypeORM format") + } + reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + default: return nil, fmt.Errorf("unsupported source format: %s", dbType) } @@ -290,6 +310,12 @@ func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaF case "pgsql", "postgres", "postgresql", "sql": writer = wpgsql.NewWriter(writerOpts) + case "prisma": + writer = wprisma.NewWriter(writerOpts) + + case "typeorm": + writer = wtypeorm.NewWriter(writerOpts) + default: return fmt.Errorf("unsupported target format: %s", dbType) } diff --git a/pkg/models/models.go b/pkg/models/models.go index 1cb8d33..8868c0b 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -40,6 +40,7 @@ type Schema struct { Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` RefDatabase *Database `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"` + Enums []*Enum `json:"enums,omitempty" yaml:"enums,omitempty" xml:"enums"` } // SQLName returns the schema name in lowercase @@ -225,6 +226,16 @@ func (d *Constraint) SQLName() string { type ConstraintType string +type Enum struct { + Name string `json:"name" yaml:"name" xml:"name"` + Values []string `json:"values" yaml:"values" xml:"values"` + Schema string `json:"schema,omitempty" yaml:"schema,omitempty" xml:"schema,omitempty"` +} + +func (d *Enum) SQLName() string { + return strings.ToLower(d.Name) +} + const ( PrimaryKeyConstraint ConstraintType = "primary_key" ForeignKeyConstraint ConstraintType = "foreign_key" diff --git a/pkg/readers/prisma/reader.go b/pkg/readers/prisma/reader.go new file mode 100644 index 0000000..6cf1495 --- /dev/null +++ b/pkg/readers/prisma/reader.go @@ -0,0 +1,823 @@ +package prisma + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/readers" +) + +// Reader implements the readers.Reader interface for Prisma schema format +type Reader struct { + options *readers.ReaderOptions +} + +// NewReader creates a new Prisma reader with the given options +func NewReader(options *readers.ReaderOptions) *Reader { + return &Reader{ + options: options, + } +} + +// ReadDatabase reads and parses Prisma schema input, returning a Database model +func (r *Reader) ReadDatabase() (*models.Database, error) { + if r.options.FilePath == "" { + return nil, fmt.Errorf("file path is required for Prisma reader") + } + + content, err := os.ReadFile(r.options.FilePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + return r.parsePrisma(string(content)) +} + +// ReadSchema reads and parses Prisma schema input, returning a Schema model +func (r *Reader) ReadSchema() (*models.Schema, error) { + db, err := r.ReadDatabase() + if err != nil { + return nil, err + } + + if len(db.Schemas) == 0 { + return nil, fmt.Errorf("no schemas found in Prisma schema") + } + + // Return the first schema + return db.Schemas[0], nil +} + +// ReadTable reads and parses Prisma schema input, returning a Table model +func (r *Reader) ReadTable() (*models.Table, error) { + schema, err := r.ReadSchema() + if err != nil { + return nil, err + } + + if len(schema.Tables) == 0 { + return nil, fmt.Errorf("no tables found in Prisma schema") + } + + // Return the first table + return schema.Tables[0], nil +} + +// stripQuotes removes surrounding quotes from an identifier +func stripQuotes(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) { + return s[1 : len(s)-1] + } + return s +} + +// parsePrisma parses Prisma schema content and returns a Database model +func (r *Reader) parsePrisma(content string) (*models.Database, error) { + db := models.InitDatabase("database") + + if r.options.Metadata != nil { + if name, ok := r.options.Metadata["name"].(string); ok { + db.Name = name + } + } + + // Default schema for Prisma (doesn't have explicit schema concept in most cases) + schema := models.InitSchema("public") + schema.Enums = make([]*models.Enum, 0) + + scanner := bufio.NewScanner(strings.NewReader(content)) + + // State tracking + var currentBlock string // "datasource", "generator", "model", "enum" + var currentTable *models.Table + var currentEnum *models.Enum + var blockContent []string + + // Regex patterns + datasourceRegex := regexp.MustCompile(`^datasource\s+\w+\s*{`) + generatorRegex := regexp.MustCompile(`^generator\s+\w+\s*{`) + modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`) + enumRegex := regexp.MustCompile(`^enum\s+(\w+)\s*{`) + + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + + // Skip empty lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "//") { + continue + } + + // Check for block start + if matches := datasourceRegex.FindStringSubmatch(trimmed); matches != nil { + currentBlock = "datasource" + blockContent = []string{} + continue + } + + if matches := generatorRegex.FindStringSubmatch(trimmed); matches != nil { + currentBlock = "generator" + blockContent = []string{} + continue + } + + if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil { + currentBlock = "model" + tableName := matches[1] + currentTable = models.InitTable(tableName, "public") + blockContent = []string{} + continue + } + + if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil { + currentBlock = "enum" + enumName := matches[1] + currentEnum = &models.Enum{ + Name: enumName, + Schema: "public", + Values: make([]string, 0), + } + blockContent = []string{} + continue + } + + // Check for block end + if trimmed == "}" { + switch currentBlock { + case "datasource": + r.parseDatasource(blockContent, db) + case "generator": + // We don't need to do anything with generator blocks + case "model": + if currentTable != nil { + r.parseModelFields(blockContent, currentTable) + schema.Tables = append(schema.Tables, currentTable) + currentTable = nil + } + case "enum": + if currentEnum != nil { + schema.Enums = append(schema.Enums, currentEnum) + currentEnum = nil + } + } + currentBlock = "" + blockContent = []string{} + continue + } + + // Accumulate block content + if currentBlock != "" { + if currentBlock == "enum" && currentEnum != nil { + // For enums, just add the trimmed value + if trimmed != "" { + currentEnum.Values = append(currentEnum.Values, trimmed) + } + } else { + blockContent = append(blockContent, line) + } + } + } + + // Second pass: resolve relationships + r.resolveRelationships(schema) + + db.Schemas = append(db.Schemas, schema) + return db, nil +} + +// parseDatasource extracts database type from datasource block +func (r *Reader) parseDatasource(lines []string, db *models.Database) { + providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`) + + for _, line := range lines { + if matches := providerRegex.FindStringSubmatch(line); matches != nil { + provider := matches[1] + switch provider { + case "postgresql", "postgres": + db.DatabaseType = models.PostgresqlDatabaseType + case "mysql": + db.DatabaseType = "mysql" + case "sqlite": + db.DatabaseType = models.SqlLiteDatabaseType + case "sqlserver": + db.DatabaseType = models.MSSQLDatabaseType + default: + db.DatabaseType = models.PostgresqlDatabaseType + } + break + } + } +} + +// parseModelFields parses model field definitions +func (r *Reader) parseModelFields(lines []string, table *models.Table) { + fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`) + blockAttrRegex := regexp.MustCompile(`^@@(\w+)\((.*?)\)`) + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip empty lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "//") { + continue + } + + // Check for block attributes (@@id, @@unique, @@index) + if matches := blockAttrRegex.FindStringSubmatch(trimmed); matches != nil { + attrName := matches[1] + attrContent := matches[2] + r.parseBlockAttribute(attrName, attrContent, table) + continue + } + + // Parse field definition + if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil { + fieldName := matches[1] + fieldType := matches[2] + modifier := matches[3] // ? or [] + attributes := matches[4] // @... part + + column := r.parseField(fieldName, fieldType, modifier, attributes, table) + if column != nil { + table.Columns[column.Name] = column + } + } + } +} + +// parseField parses a single field definition +func (r *Reader) parseField(name, fieldType, modifier, attributes string, table *models.Table) *models.Column { + // Check if this is a relation field (array or references another model) + if modifier == "[]" { + // Array field - this is a relation field, not a column + // We'll handle this in relationship resolution + return nil + } + + // Check if this is a non-primitive type (relation field) + // Note: We need to allow enum types through as they're like primitives + if !r.isPrimitiveType(fieldType) && !r.isEnumType(fieldType, table) { + // This is a relation field (e.g., user User), not a scalar column + // Only process this if it has @relation attribute (which means it's the owning side with FK) + // Otherwise skip it as it's just the inverse relation field + if attributes == "" || !strings.Contains(attributes, "@relation") { + return nil + } + // If it has @relation, we still don't create a column for it + // The actual FK column will be in the fields: [...] part of @relation + return nil + } + + column := models.InitColumn(name, table.Name, table.Schema) + + // Map Prisma type to SQL type + column.Type = r.prismaTypeToSQL(fieldType) + + // Handle modifiers + if modifier == "?" { + column.NotNull = false + } else { + // Default: required fields are NOT NULL + column.NotNull = true + } + + // Parse field attributes + if attributes != "" { + r.parseFieldAttributes(attributes, column, table) + } + + return column +} + +// prismaTypeToSQL converts Prisma types to SQL types +func (r *Reader) prismaTypeToSQL(prismaType string) string { + typeMap := map[string]string{ + "String": "text", + "Boolean": "boolean", + "Int": "integer", + "BigInt": "bigint", + "Float": "double precision", + "Decimal": "decimal", + "DateTime": "timestamp", + "Json": "jsonb", + "Bytes": "bytea", + } + + if sqlType, ok := typeMap[prismaType]; ok { + return sqlType + } + + // If not a built-in type, it might be an enum or model reference + // For enums, we'll use the enum name directly + return prismaType +} + +// parseFieldAttributes parses field attributes like @id, @unique, @default +func (r *Reader) parseFieldAttributes(attributes string, column *models.Column, table *models.Table) { + // @id attribute + if strings.Contains(attributes, "@id") { + column.IsPrimaryKey = true + column.NotNull = true + } + + // @unique attribute + if regexp.MustCompile(`@unique\b`).MatchString(attributes) { + uniqueConstraint := models.InitConstraint( + fmt.Sprintf("uq_%s", column.Name), + models.UniqueConstraint, + ) + uniqueConstraint.Schema = table.Schema + uniqueConstraint.Table = table.Name + uniqueConstraint.Columns = []string{column.Name} + table.Constraints[uniqueConstraint.Name] = uniqueConstraint + } + + // @default attribute - extract value with balanced parentheses + if strings.Contains(attributes, "@default(") { + defaultValue := r.extractDefaultValue(attributes) + if defaultValue != "" { + r.parseDefaultValue(defaultValue, column) + } + } + + // @updatedAt attribute - store in comment for now + if strings.Contains(attributes, "@updatedAt") { + if column.Comment != "" { + column.Comment += "; @updatedAt" + } else { + column.Comment = "@updatedAt" + } + } + + // @relation attribute - we'll handle this in relationship resolution + // For now, just note that this field is part of a relation +} + +// extractDefaultValue extracts the default value from @default(...) handling nested parentheses +func (r *Reader) extractDefaultValue(attributes string) string { + idx := strings.Index(attributes, "@default(") + if idx == -1 { + return "" + } + + start := idx + len("@default(") + depth := 1 + i := start + + for i < len(attributes) && depth > 0 { + if attributes[i] == '(' { + depth++ + } else if attributes[i] == ')' { + depth-- + } + i++ + } + + if depth == 0 { + return attributes[start : i-1] + } + + return "" +} + +// parseDefaultValue parses Prisma default value expressions +func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) { + defaultExpr = strings.TrimSpace(defaultExpr) + + switch defaultExpr { + case "autoincrement()": + column.AutoIncrement = true + case "now()": + column.Default = "now()" + case "uuid()": + column.Default = "gen_random_uuid()" + case "cuid()": + // CUID is Prisma-specific, store in comment + if column.Comment != "" { + column.Comment += "; default(cuid())" + } else { + column.Comment = "default(cuid())" + } + case "true": + column.Default = true + case "false": + column.Default = false + default: + // Check if it's a string literal + if strings.HasPrefix(defaultExpr, "\"") && strings.HasSuffix(defaultExpr, "\"") { + column.Default = defaultExpr[1 : len(defaultExpr)-1] + } else if strings.HasPrefix(defaultExpr, "'") && strings.HasSuffix(defaultExpr, "'") { + column.Default = defaultExpr[1 : len(defaultExpr)-1] + } else { + // Try to parse as number or enum value + column.Default = defaultExpr + } + } +} + +// parseBlockAttribute parses block-level attributes like @@id, @@unique, @@index +func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Table) { + // Extract column list from brackets [col1, col2] + colListRegex := regexp.MustCompile(`\[(.*?)\]`) + matches := colListRegex.FindStringSubmatch(content) + if matches == nil { + return + } + + columnList := strings.Split(matches[1], ",") + columns := make([]string, 0) + for _, col := range columnList { + columns = append(columns, strings.TrimSpace(col)) + } + + switch attrName { + case "id": + // Composite primary key + for _, colName := range columns { + if col, exists := table.Columns[colName]; exists { + col.IsPrimaryKey = true + col.NotNull = true + } + } + // Also create a PK constraint + pkConstraint := models.InitConstraint( + fmt.Sprintf("pk_%s", table.Name), + models.PrimaryKeyConstraint, + ) + pkConstraint.Schema = table.Schema + pkConstraint.Table = table.Name + pkConstraint.Columns = columns + table.Constraints[pkConstraint.Name] = pkConstraint + + case "unique": + // Multi-column unique constraint + uniqueConstraint := models.InitConstraint( + fmt.Sprintf("uq_%s_%s", table.Name, strings.Join(columns, "_")), + models.UniqueConstraint, + ) + uniqueConstraint.Schema = table.Schema + uniqueConstraint.Table = table.Name + uniqueConstraint.Columns = columns + table.Constraints[uniqueConstraint.Name] = uniqueConstraint + + case "index": + // Index + index := models.InitIndex( + fmt.Sprintf("idx_%s_%s", table.Name, strings.Join(columns, "_")), + table.Name, + table.Schema, + ) + index.Columns = columns + table.Indexes[index.Name] = index + } +} + +// relationField stores information about a relation field for second-pass processing +type relationField struct { + tableName string + fieldName string + relatedModel string + isArray bool + relationAttr string +} + +// resolveRelationships performs a second pass to resolve @relation attributes +func (r *Reader) resolveRelationships(schema *models.Schema) { + // Build a map of table names for quick lookup + tableMap := make(map[string]*models.Table) + for _, table := range schema.Tables { + tableMap[table.Name] = table + } + + // First, we need to re-parse to find relation fields + // We'll re-read the file to extract relation information + if r.options.FilePath == "" { + return + } + + content, err := os.ReadFile(r.options.FilePath) + if err != nil { + return + } + + relations := r.extractRelationFields(string(content)) + + // Process explicit @relation attributes to create FK constraints + for _, rel := range relations { + if rel.relationAttr != "" { + r.createConstraintFromRelation(rel, tableMap, schema) + } + } + + // Detect implicit many-to-many relationships + r.detectImplicitManyToMany(relations, tableMap, schema) +} + +// extractRelationFields extracts relation field information from the schema +func (r *Reader) extractRelationFields(content string) []relationField { + relations := make([]relationField, 0) + scanner := bufio.NewScanner(strings.NewReader(content)) + + modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`) + fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`) + + var currentModel string + inModel := false + + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + + if trimmed == "" || strings.HasPrefix(trimmed, "//") { + continue + } + + if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil { + currentModel = matches[1] + inModel = true + continue + } + + if trimmed == "}" { + inModel = false + currentModel = "" + continue + } + + if inModel && currentModel != "" { + if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil { + fieldName := matches[1] + fieldType := matches[2] + modifier := matches[3] + attributes := matches[4] + + // Check if this is a relation field (references another model or is an array) + isPotentialRelation := modifier == "[]" || !r.isPrimitiveType(fieldType) + + if isPotentialRelation { + rel := relationField{ + tableName: currentModel, + fieldName: fieldName, + relatedModel: fieldType, + isArray: modifier == "[]", + relationAttr: attributes, + } + relations = append(relations, rel) + } + } + } + } + + return relations +} + +// isPrimitiveType checks if a type is a Prisma primitive type +func (r *Reader) isPrimitiveType(typeName string) bool { + primitives := []string{"String", "Boolean", "Int", "BigInt", "Float", "Decimal", "DateTime", "Json", "Bytes"} + for _, p := range primitives { + if typeName == p { + return true + } + } + return false +} + +// isEnumType checks if a type name might be an enum +// Note: We can't definitively check against schema.Enums at parse time +// because enums might be defined after the model, so we just check +// if it starts with uppercase (Prisma convention for enums) +func (r *Reader) isEnumType(typeName string, table *models.Table) bool { + // Simple heuristic: enum types start with uppercase letter + // and are not known model names (though we can't check that yet) + if len(typeName) > 0 && typeName[0] >= 'A' && typeName[0] <= 'Z' { + // Additional check: primitive types are already handled above + // So if it's uppercase and not primitive, it's likely an enum or model + // We'll assume it's an enum if it's a single word + return !strings.Contains(typeName, "_") + } + return false +} + +// createConstraintFromRelation creates a FK constraint from a @relation attribute +func (r *Reader) createConstraintFromRelation(rel relationField, tableMap map[string]*models.Table, schema *models.Schema) { + // Skip array fields (they are the inverse side of the relation) + if rel.isArray { + return + } + + if rel.relationAttr == "" { + return + } + + // Parse @relation attribute + relationRegex := regexp.MustCompile(`@relation\((.*?)\)`) + matches := relationRegex.FindStringSubmatch(rel.relationAttr) + if matches == nil { + return + } + + relationContent := matches[1] + + // Extract fields and references + fieldsRegex := regexp.MustCompile(`fields:\s*\[(.*?)\]`) + referencesRegex := regexp.MustCompile(`references:\s*\[(.*?)\]`) + nameRegex := regexp.MustCompile(`name:\s*"([^"]+)"`) + onDeleteRegex := regexp.MustCompile(`onDelete:\s*(\w+)`) + onUpdateRegex := regexp.MustCompile(`onUpdate:\s*(\w+)`) + + fieldsMatch := fieldsRegex.FindStringSubmatch(relationContent) + referencesMatch := referencesRegex.FindStringSubmatch(relationContent) + + if fieldsMatch == nil || referencesMatch == nil { + return + } + + // Parse field and reference column lists + fieldCols := r.parseColumnList(fieldsMatch[1]) + refCols := r.parseColumnList(referencesMatch[1]) + + if len(fieldCols) == 0 || len(refCols) == 0 { + return + } + + // Create FK constraint + constraintName := fmt.Sprintf("fk_%s_%s", rel.tableName, fieldCols[0]) + + // Check for custom name + if nameMatch := nameRegex.FindStringSubmatch(relationContent); nameMatch != nil { + constraintName = nameMatch[1] + } + + constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint) + constraint.Schema = "public" + constraint.Table = rel.tableName + constraint.Columns = fieldCols + constraint.ReferencedSchema = "public" + constraint.ReferencedTable = rel.relatedModel + constraint.ReferencedColumns = refCols + + // Parse referential actions + if onDeleteMatch := onDeleteRegex.FindStringSubmatch(relationContent); onDeleteMatch != nil { + constraint.OnDelete = onDeleteMatch[1] + } + + if onUpdateMatch := onUpdateRegex.FindStringSubmatch(relationContent); onUpdateMatch != nil { + constraint.OnUpdate = onUpdateMatch[1] + } + + // Add constraint to table + if table, exists := tableMap[rel.tableName]; exists { + table.Constraints[constraint.Name] = constraint + } +} + +// parseColumnList parses a comma-separated list of column names +func (r *Reader) parseColumnList(list string) []string { + parts := strings.Split(list, ",") + result := make([]string, 0) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +// detectImplicitManyToMany detects implicit M2M relationships and creates join tables +func (r *Reader) detectImplicitManyToMany(relations []relationField, tableMap map[string]*models.Table, schema *models.Schema) { + // Group relations by model pairs + type modelPair struct { + model1 string + model2 string + } + + pairMap := make(map[modelPair][]relationField) + + for _, rel := range relations { + if !rel.isArray || rel.relationAttr != "" { + // Skip non-array fields and explicit relations + continue + } + + // Create a normalized pair (alphabetically sorted to avoid duplicates) + pair := modelPair{} + if rel.tableName < rel.relatedModel { + pair.model1 = rel.tableName + pair.model2 = rel.relatedModel + } else { + pair.model1 = rel.relatedModel + pair.model2 = rel.tableName + } + + pairMap[pair] = append(pairMap[pair], rel) + } + + // Check for pairs with arrays on both sides (implicit M2M) + for pair, rels := range pairMap { + if len(rels) >= 2 { + // This is an implicit many-to-many relationship + r.createImplicitJoinTable(pair.model1, pair.model2, tableMap, schema) + } + } +} + +// createImplicitJoinTable creates a virtual join table for implicit M2M relations +func (r *Reader) createImplicitJoinTable(model1, model2 string, tableMap map[string]*models.Table, schema *models.Schema) { + // Prisma naming convention: _Model1ToModel2 (alphabetically sorted) + joinTableName := fmt.Sprintf("_%sTo%s", model1, model2) + + // Check if join table already exists + if _, exists := tableMap[joinTableName]; exists { + return + } + + // Create join table + joinTable := models.InitTable(joinTableName, "public") + + // Get primary keys from both tables + pk1 := r.getPrimaryKeyColumn(tableMap[model1]) + pk2 := r.getPrimaryKeyColumn(tableMap[model2]) + + if pk1 == nil || pk2 == nil { + return // Can't create join table without PKs + } + + // Create FK columns in join table + fkCol1Name := fmt.Sprintf("%sId", model1) + fkCol1 := models.InitColumn(fkCol1Name, joinTableName, "public") + fkCol1.Type = pk1.Type + fkCol1.NotNull = true + joinTable.Columns[fkCol1Name] = fkCol1 + + fkCol2Name := fmt.Sprintf("%sId", model2) + fkCol2 := models.InitColumn(fkCol2Name, joinTableName, "public") + fkCol2.Type = pk2.Type + fkCol2.NotNull = true + joinTable.Columns[fkCol2Name] = fkCol2 + + // Create composite primary key + pkConstraint := models.InitConstraint( + fmt.Sprintf("pk_%s", joinTableName), + models.PrimaryKeyConstraint, + ) + pkConstraint.Schema = "public" + pkConstraint.Table = joinTableName + pkConstraint.Columns = []string{fkCol1Name, fkCol2Name} + joinTable.Constraints[pkConstraint.Name] = pkConstraint + + // Mark columns as PK + fkCol1.IsPrimaryKey = true + fkCol2.IsPrimaryKey = true + + // Create FK constraints + fk1 := models.InitConstraint( + fmt.Sprintf("fk_%s_%s", joinTableName, model1), + models.ForeignKeyConstraint, + ) + fk1.Schema = "public" + fk1.Table = joinTableName + fk1.Columns = []string{fkCol1Name} + fk1.ReferencedSchema = "public" + fk1.ReferencedTable = model1 + fk1.ReferencedColumns = []string{pk1.Name} + fk1.OnDelete = "Cascade" + joinTable.Constraints[fk1.Name] = fk1 + + fk2 := models.InitConstraint( + fmt.Sprintf("fk_%s_%s", joinTableName, model2), + models.ForeignKeyConstraint, + ) + fk2.Schema = "public" + fk2.Table = joinTableName + fk2.Columns = []string{fkCol2Name} + fk2.ReferencedSchema = "public" + fk2.ReferencedTable = model2 + fk2.ReferencedColumns = []string{pk2.Name} + fk2.OnDelete = "Cascade" + joinTable.Constraints[fk2.Name] = fk2 + + // Add join table to schema + schema.Tables = append(schema.Tables, joinTable) + tableMap[joinTableName] = joinTable +} + +// getPrimaryKeyColumn returns the primary key column of a table +func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column { + if table == nil { + return nil + } + + for _, col := range table.Columns { + if col.IsPrimaryKey { + return col + } + } + + return nil +} diff --git a/pkg/readers/typeorm/reader.go b/pkg/readers/typeorm/reader.go new file mode 100644 index 0000000..0b6d300 --- /dev/null +++ b/pkg/readers/typeorm/reader.go @@ -0,0 +1,785 @@ +package typeorm + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/readers" +) + +// Reader implements the readers.Reader interface for TypeORM entity files +type Reader struct { + options *readers.ReaderOptions +} + +// NewReader creates a new TypeORM reader with the given options +func NewReader(options *readers.ReaderOptions) *Reader { + return &Reader{ + options: options, + } +} + +// ReadDatabase reads and parses TypeORM entity files, returning a Database model +func (r *Reader) ReadDatabase() (*models.Database, error) { + if r.options.FilePath == "" { + return nil, fmt.Errorf("file path is required for TypeORM reader") + } + + content, err := os.ReadFile(r.options.FilePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + return r.parseTypeORM(string(content)) +} + +// ReadSchema reads and parses TypeORM entity files, returning a Schema model +func (r *Reader) ReadSchema() (*models.Schema, error) { + db, err := r.ReadDatabase() + if err != nil { + return nil, err + } + + if len(db.Schemas) == 0 { + return nil, fmt.Errorf("no schemas found in TypeORM entities") + } + + return db.Schemas[0], nil +} + +// ReadTable reads and parses TypeORM entity files, returning a Table model +func (r *Reader) ReadTable() (*models.Table, error) { + schema, err := r.ReadSchema() + if err != nil { + return nil, err + } + + if len(schema.Tables) == 0 { + return nil, fmt.Errorf("no tables found in TypeORM entities") + } + + return schema.Tables[0], nil +} + +// entityInfo stores information about an entity during parsing +type entityInfo struct { + name string + fields []fieldInfo + decorators []string +} + +// fieldInfo stores information about a field during parsing +type fieldInfo struct { + name string + typeName string + decorators []string +} + +// parseTypeORM parses TypeORM entity content and returns a Database model +func (r *Reader) parseTypeORM(content string) (*models.Database, error) { + db := models.InitDatabase("database") + schema := models.InitSchema("public") + + // Parse entities + entities := r.extractEntities(content) + + // Convert entities to tables and views + tableMap := make(map[string]*models.Table) + for _, entity := range entities { + // Check if this is a view + isView := false + for _, decorator := range entity.decorators { + if strings.HasPrefix(decorator, "@ViewEntity") { + isView = true + break + } + } + + if isView { + view := r.entityToView(entity) + schema.Views = append(schema.Views, view) + } else { + table := r.entityToTable(entity) + schema.Tables = append(schema.Tables, table) + tableMap[table.Name] = table + } + } + + // Second pass: resolve relationships + r.resolveRelationships(entities, tableMap, schema) + + db.Schemas = append(db.Schemas, schema) + return db, nil +} + +// extractEntities extracts entity and view definitions from TypeORM content +func (r *Reader) extractEntities(content string) []entityInfo { + entities := make([]entityInfo, 0) + + // First, extract decorators properly (handling multi-line) + content = r.normalizeDecorators(content) + + scanner := bufio.NewScanner(strings.NewReader(content)) + + entityRegex := regexp.MustCompile(`^export\s+class\s+(\w+)`) + decoratorRegex := regexp.MustCompile(`^\s*@(\w+)(\([^)]*\))?`) + fieldRegex := regexp.MustCompile(`^\s*(\w+):\s*([^;]+);`) + + var currentEntity *entityInfo + var pendingDecorators []string + inClass := false + + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + + // Skip empty lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "import ") { + continue + } + + // Check for decorator + if matches := decoratorRegex.FindStringSubmatch(trimmed); matches != nil { + decorator := matches[0] + pendingDecorators = append(pendingDecorators, decorator) + continue + } + + // Check for entity/view class + if matches := entityRegex.FindStringSubmatch(trimmed); matches != nil { + // Save previous entity if exists + if currentEntity != nil { + entities = append(entities, *currentEntity) + } + + currentEntity = &entityInfo{ + name: matches[1], + fields: make([]fieldInfo, 0), + decorators: pendingDecorators, + } + pendingDecorators = []string{} + inClass = true + continue + } + + // Check for class end + if inClass && trimmed == "}" { + if currentEntity != nil { + entities = append(entities, *currentEntity) + currentEntity = nil + } + inClass = false + pendingDecorators = []string{} + continue + } + + // Check for field definition + if inClass && currentEntity != nil { + if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil { + fieldName := matches[1] + fieldType := strings.TrimSpace(matches[2]) + + field := fieldInfo{ + name: fieldName, + typeName: fieldType, + decorators: pendingDecorators, + } + currentEntity.fields = append(currentEntity.fields, field) + pendingDecorators = []string{} + } + } + } + + // Save last entity + if currentEntity != nil { + entities = append(entities, *currentEntity) + } + + return entities +} + +// normalizeDecorators combines multi-line decorators into single lines +func (r *Reader) normalizeDecorators(content string) string { + // Replace multi-line decorators with single-line versions + // Match @Decorator({ ... }) across multiple lines + decoratorRegex := regexp.MustCompile(`@(\w+)\s*\(\s*\{([^}]*)\}\s*\)`) + + return decoratorRegex.ReplaceAllStringFunc(content, func(match string) string { + // Remove newlines and extra spaces from decorator + match = strings.ReplaceAll(match, "\n", " ") + match = strings.ReplaceAll(match, "\r", " ") + // Normalize multiple spaces + spaceRegex := regexp.MustCompile(`\s+`) + match = spaceRegex.ReplaceAllString(match, " ") + return match + }) +} + +// entityToView converts a view entity to a view +func (r *Reader) entityToView(entity entityInfo) *models.View { + // Parse @ViewEntity decorator options + viewName := entity.name + schemaName := "public" + var expression string + + for _, decorator := range entity.decorators { + if strings.HasPrefix(decorator, "@ViewEntity") { + // Extract options from @ViewEntity({ ... }) + options := r.parseViewEntityOptions(decorator) + + // Check for custom view name + if name, ok := options["name"]; ok { + viewName = name + } + + // Check for schema + if schema, ok := options["schema"]; ok { + schemaName = schema + } + + // Check for expression (SQL definition) + if expr, ok := options["expression"]; ok { + expression = expr + } + break + } + } + + view := models.InitView(viewName, schemaName) + view.Definition = expression + + // Add columns from fields (if any are defined in the view class) + for _, field := range entity.fields { + column := models.InitColumn(field.name, viewName, schemaName) + column.Type = r.typeScriptTypeToSQL(field.typeName) + view.Columns[column.Name] = column + } + + return view +} + +// parseViewEntityOptions parses @ViewEntity decorator options +func (r *Reader) parseViewEntityOptions(decorator string) map[string]string { + options := make(map[string]string) + + // Extract content between parentheses + start := strings.Index(decorator, "(") + end := strings.LastIndex(decorator, ")") + + if start == -1 || end == -1 || start >= end { + return options + } + + content := decorator[start+1 : end] + + // Skip if empty @ViewEntity() + if strings.TrimSpace(content) == "" { + return options + } + + // Parse name: "value" + nameRegex := regexp.MustCompile(`name:\s*["']([^"']+)["']`) + if matches := nameRegex.FindStringSubmatch(content); matches != nil { + options["name"] = matches[1] + } + + // Parse schema: "value" + schemaRegex := regexp.MustCompile(`schema:\s*["']([^"']+)["']`) + if matches := schemaRegex.FindStringSubmatch(content); matches != nil { + options["schema"] = matches[1] + } + + // Parse expression: ` ... ` (can be multi-line, captured as single line after normalization) + // Look for expression followed by backtick or quote + expressionRegex := regexp.MustCompile(`expression:\s*` + "`" + `([^` + "`" + `]+)` + "`") + if matches := expressionRegex.FindStringSubmatch(content); matches != nil { + options["expression"] = strings.TrimSpace(matches[1]) + } else { + // Try with regular quotes + expressionRegex = regexp.MustCompile(`expression:\s*["']([^"']+)["']`) + if matches := expressionRegex.FindStringSubmatch(content); matches != nil { + options["expression"] = strings.TrimSpace(matches[1]) + } + } + + return options +} + +// entityToTable converts an entity to a table +func (r *Reader) entityToTable(entity entityInfo) *models.Table { + // Parse @Entity decorator options + tableName := entity.name + schemaName := "public" + var entityOptions map[string]string + + for _, decorator := range entity.decorators { + if strings.HasPrefix(decorator, "@Entity") { + // Extract options from @Entity({ ... }) + entityOptions = r.parseEntityOptions(decorator) + + // Check for custom table name + if name, ok := entityOptions["name"]; ok { + tableName = name + } + + // Check for schema + if schema, ok := entityOptions["schema"]; ok { + schemaName = schema + } + break + } + } + + table := models.InitTable(tableName, schemaName) + + // Store additional metadata from @Entity options + if entityOptions != nil { + // Store database name in metadata + if database, ok := entityOptions["database"]; ok { + if table.Metadata == nil { + table.Metadata = make(map[string]any) + } + table.Metadata["database"] = database + } + + // Store engine in metadata + if engine, ok := entityOptions["engine"]; ok { + if table.Metadata == nil { + table.Metadata = make(map[string]any) + } + table.Metadata["engine"] = engine + } + + // Store original class name if different from table name + if entity.name != tableName { + if table.Metadata == nil { + table.Metadata = make(map[string]any) + } + table.Metadata["class_name"] = entity.name + } + } + + for _, field := range entity.fields { + // Skip relation fields (they'll be handled in relationship resolution) + if r.isRelationField(field) { + continue + } + + column := r.fieldToColumn(field, table) + if column != nil { + table.Columns[column.Name] = column + } + } + + return table +} + +// parseEntityOptions parses @Entity decorator options +func (r *Reader) parseEntityOptions(decorator string) map[string]string { + options := make(map[string]string) + + // Extract content between parentheses + start := strings.Index(decorator, "(") + end := strings.LastIndex(decorator, ")") + + if start == -1 || end == -1 || start >= end { + return options + } + + content := decorator[start+1 : end] + + // Skip if empty @Entity() + if strings.TrimSpace(content) == "" { + return options + } + + // Parse name: "value" or name: 'value' + nameRegex := regexp.MustCompile(`name:\s*["']([^"']+)["']`) + if matches := nameRegex.FindStringSubmatch(content); matches != nil { + options["name"] = matches[1] + } + + // Parse schema: "value" + schemaRegex := regexp.MustCompile(`schema:\s*["']([^"']+)["']`) + if matches := schemaRegex.FindStringSubmatch(content); matches != nil { + options["schema"] = matches[1] + } + + // Parse database: "value" + databaseRegex := regexp.MustCompile(`database:\s*["']([^"']+)["']`) + if matches := databaseRegex.FindStringSubmatch(content); matches != nil { + options["database"] = matches[1] + } + + // Parse engine: "value" + engineRegex := regexp.MustCompile(`engine:\s*["']([^"']+)["']`) + if matches := engineRegex.FindStringSubmatch(content); matches != nil { + options["engine"] = matches[1] + } + + return options +} + +// isRelationField checks if a field is a relation field +func (r *Reader) isRelationField(field fieldInfo) bool { + for _, decorator := range field.decorators { + if strings.Contains(decorator, "@ManyToOne") || + strings.Contains(decorator, "@OneToMany") || + strings.Contains(decorator, "@ManyToMany") || + strings.Contains(decorator, "@OneToOne") { + return true + } + } + return false +} + +// fieldToColumn converts a field to a column +func (r *Reader) fieldToColumn(field fieldInfo, table *models.Table) *models.Column { + column := models.InitColumn(field.name, table.Name, table.Schema) + + // Map TypeScript type to SQL type + column.Type = r.typeScriptTypeToSQL(field.typeName) + + // Default to NOT NULL + column.NotNull = true + + // Parse decorators + for _, decorator := range field.decorators { + r.parseColumnDecorator(decorator, column, table) + } + + return column +} + +// typeScriptTypeToSQL converts TypeScript types to SQL types +func (r *Reader) typeScriptTypeToSQL(tsType string) string { + // Remove array brackets and optional markers + tsType = strings.TrimSuffix(tsType, "[]") + tsType = strings.TrimSuffix(tsType, " | null") + + typeMap := map[string]string{ + "string": "text", + "number": "integer", + "boolean": "boolean", + "Date": "timestamp", + "any": "jsonb", + } + + for tsPattern, sqlType := range typeMap { + if strings.Contains(tsType, tsPattern) { + return sqlType + } + } + + // Default to text + return "text" +} + +// parseColumnDecorator parses a column decorator +func (r *Reader) parseColumnDecorator(decorator string, column *models.Column, table *models.Table) { + // @PrimaryGeneratedColumn + if strings.HasPrefix(decorator, "@PrimaryGeneratedColumn") { + column.IsPrimaryKey = true + column.NotNull = true + + if strings.Contains(decorator, "'uuid'") { + column.Type = "uuid" + column.Default = "gen_random_uuid()" + } else if strings.Contains(decorator, "'increment'") || strings.Contains(decorator, "()") { + column.AutoIncrement = true + } + return + } + + // @Column + if strings.HasPrefix(decorator, "@Column") { + r.parseColumnOptions(decorator, column, table) + return + } + + // @CreateDateColumn + if strings.HasPrefix(decorator, "@CreateDateColumn") { + column.Type = "timestamp" + column.Default = "now()" + column.NotNull = true + return + } + + // @UpdateDateColumn + if strings.HasPrefix(decorator, "@UpdateDateColumn") { + column.Type = "timestamp" + column.NotNull = true + if column.Comment != "" { + column.Comment += "; auto-update" + } else { + column.Comment = "auto-update" + } + return + } +} + +// parseColumnOptions parses @Column decorator options +func (r *Reader) parseColumnOptions(decorator string, column *models.Column, table *models.Table) { + // Extract content between parentheses + start := strings.Index(decorator, "(") + end := strings.LastIndex(decorator, ")") + + if start == -1 || end == -1 || start >= end { + return + } + + content := decorator[start+1 : end] + + // Check for shorthand type: @Column('text') + if strings.HasPrefix(content, "'") || strings.HasPrefix(content, "\"") { + typeStr := strings.Trim(content, "'\"`") + column.Type = typeStr + return + } + + // Parse options object + if strings.Contains(content, "type:") { + typeRegex := regexp.MustCompile(`type:\s*['"]([^'"]+)['"]`) + if matches := typeRegex.FindStringSubmatch(content); matches != nil { + column.Type = matches[1] + } + } + + if strings.Contains(content, "nullable: true") || strings.Contains(content, "nullable:true") { + column.NotNull = false + } + + if strings.Contains(content, "unique: true") || strings.Contains(content, "unique:true") { + uniqueConstraint := models.InitConstraint( + fmt.Sprintf("uq_%s", column.Name), + models.UniqueConstraint, + ) + uniqueConstraint.Schema = table.Schema + uniqueConstraint.Table = table.Name + uniqueConstraint.Columns = []string{column.Name} + table.Constraints[uniqueConstraint.Name] = uniqueConstraint + } + + if strings.Contains(content, "default:") { + defaultRegex := regexp.MustCompile(`default:\s*['"]?([^,}'"]+)['"]?`) + if matches := defaultRegex.FindStringSubmatch(content); matches != nil { + defaultValue := strings.TrimSpace(matches[1]) + defaultValue = strings.Trim(defaultValue, "'\"") + column.Default = defaultValue + } + } +} + +// resolveRelationships resolves TypeORM relationships +func (r *Reader) resolveRelationships(entities []entityInfo, tableMap map[string]*models.Table, schema *models.Schema) { + // Track M2M relations that need join tables + type m2mRelation struct { + ownerEntity string + targetEntity string + ownerField string + } + m2mRelations := make([]m2mRelation, 0) + + for _, entity := range entities { + table := tableMap[entity.name] + if table == nil { + continue + } + + for _, field := range entity.fields { + // Handle @ManyToOne relations + if r.hasDecorator(field, "@ManyToOne") { + r.createManyToOneConstraint(field, entity.name, table, tableMap) + } + + // Track @ManyToMany relations with @JoinTable + if r.hasDecorator(field, "@ManyToMany") && r.hasDecorator(field, "@JoinTable") { + targetEntity := r.extractRelationTarget(field) + if targetEntity != "" { + m2mRelations = append(m2mRelations, m2mRelation{ + ownerEntity: entity.name, + targetEntity: targetEntity, + ownerField: field.name, + }) + } + } + } + } + + // Create join tables for M2M relations + for _, rel := range m2mRelations { + r.createManyToManyJoinTable(rel.ownerEntity, rel.targetEntity, tableMap, schema) + } +} + +// hasDecorator checks if a field has a specific decorator +func (r *Reader) hasDecorator(field fieldInfo, decoratorName string) bool { + for _, decorator := range field.decorators { + if strings.HasPrefix(decorator, decoratorName) { + return true + } + } + return false +} + +// extractRelationTarget extracts the target entity from a relation decorator +func (r *Reader) extractRelationTarget(field fieldInfo) string { + // Remove array brackets from type + targetType := strings.TrimSuffix(field.typeName, "[]") + targetType = strings.TrimSpace(targetType) + return targetType +} + +// createManyToOneConstraint creates a foreign key constraint for @ManyToOne +func (r *Reader) createManyToOneConstraint(field fieldInfo, entityName string, table *models.Table, tableMap map[string]*models.Table) { + targetEntity := r.extractRelationTarget(field) + if targetEntity == "" { + return + } + + // Get target table to find its PK + targetTable := tableMap[targetEntity] + if targetTable == nil { + return + } + + targetPK := r.getPrimaryKeyColumn(targetTable) + if targetPK == nil { + return + } + + // Create FK column + fkColumnName := fmt.Sprintf("%sId", field.name) + fkColumn := models.InitColumn(fkColumnName, table.Name, table.Schema) + fkColumn.Type = targetPK.Type + + // Check if nullable option is set in @ManyToOne decorator + isNullable := false + for _, decorator := range field.decorators { + if strings.Contains(decorator, "nullable: true") || strings.Contains(decorator, "nullable:true") { + isNullable = true + break + } + } + fkColumn.NotNull = !isNullable + + table.Columns[fkColumnName] = fkColumn + + // Create FK constraint + constraint := models.InitConstraint( + fmt.Sprintf("fk_%s_%s", entityName, field.name), + models.ForeignKeyConstraint, + ) + constraint.Schema = table.Schema + constraint.Table = table.Name + constraint.Columns = []string{fkColumnName} + constraint.ReferencedSchema = "public" + constraint.ReferencedTable = targetEntity + constraint.ReferencedColumns = []string{targetPK.Name} + constraint.OnDelete = "CASCADE" + + table.Constraints[constraint.Name] = constraint +} + +// createManyToManyJoinTable creates a join table for M2M relations +func (r *Reader) createManyToManyJoinTable(entity1, entity2 string, tableMap map[string]*models.Table, schema *models.Schema) { + // TypeORM naming convention: entity1_entity2_entity1field + // We'll simplify to entity1_entity2 + joinTableName := fmt.Sprintf("%s_%s", strings.ToLower(entity1), strings.ToLower(entity2)) + + // Check if join table already exists + if _, exists := tableMap[joinTableName]; exists { + return + } + + // Get PKs from both tables + table1 := tableMap[entity1] + table2 := tableMap[entity2] + if table1 == nil || table2 == nil { + return + } + + pk1 := r.getPrimaryKeyColumn(table1) + pk2 := r.getPrimaryKeyColumn(table2) + if pk1 == nil || pk2 == nil { + return + } + + // Create join table + joinTable := models.InitTable(joinTableName, "public") + + // Create FK columns + fkCol1Name := fmt.Sprintf("%sId", strings.ToLower(entity1)) + fkCol1 := models.InitColumn(fkCol1Name, joinTableName, "public") + fkCol1.Type = pk1.Type + fkCol1.NotNull = true + fkCol1.IsPrimaryKey = true + joinTable.Columns[fkCol1Name] = fkCol1 + + fkCol2Name := fmt.Sprintf("%sId", strings.ToLower(entity2)) + fkCol2 := models.InitColumn(fkCol2Name, joinTableName, "public") + fkCol2.Type = pk2.Type + fkCol2.NotNull = true + fkCol2.IsPrimaryKey = true + joinTable.Columns[fkCol2Name] = fkCol2 + + // Create composite PK constraint + pkConstraint := models.InitConstraint( + fmt.Sprintf("pk_%s", joinTableName), + models.PrimaryKeyConstraint, + ) + pkConstraint.Schema = "public" + pkConstraint.Table = joinTableName + pkConstraint.Columns = []string{fkCol1Name, fkCol2Name} + joinTable.Constraints[pkConstraint.Name] = pkConstraint + + // Create FK constraints + fk1 := models.InitConstraint( + fmt.Sprintf("fk_%s_%s", joinTableName, entity1), + models.ForeignKeyConstraint, + ) + fk1.Schema = "public" + fk1.Table = joinTableName + fk1.Columns = []string{fkCol1Name} + fk1.ReferencedSchema = "public" + fk1.ReferencedTable = entity1 + fk1.ReferencedColumns = []string{pk1.Name} + fk1.OnDelete = "CASCADE" + joinTable.Constraints[fk1.Name] = fk1 + + fk2 := models.InitConstraint( + fmt.Sprintf("fk_%s_%s", joinTableName, entity2), + models.ForeignKeyConstraint, + ) + fk2.Schema = "public" + fk2.Table = joinTableName + fk2.Columns = []string{fkCol2Name} + fk2.ReferencedSchema = "public" + fk2.ReferencedTable = entity2 + fk2.ReferencedColumns = []string{pk2.Name} + fk2.OnDelete = "CASCADE" + joinTable.Constraints[fk2.Name] = fk2 + + // Add join table to schema + schema.Tables = append(schema.Tables, joinTable) + tableMap[joinTableName] = joinTable +} + +// getPrimaryKeyColumn returns the primary key column of a table +func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column { + if table == nil { + return nil + } + + for _, col := range table.Columns { + if col.IsPrimaryKey { + return col + } + } + + return nil +} diff --git a/pkg/writers/prisma/writer.go b/pkg/writers/prisma/writer.go new file mode 100644 index 0000000..c1471c9 --- /dev/null +++ b/pkg/writers/prisma/writer.go @@ -0,0 +1,551 @@ +package prisma + +import ( + "fmt" + "os" + "sort" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +// Writer implements the writers.Writer interface for Prisma schema format +type Writer struct { + options *writers.WriterOptions +} + +// NewWriter creates a new Prisma writer with the given options +func NewWriter(options *writers.WriterOptions) *Writer { + return &Writer{ + options: options, + } +} + +// WriteDatabase writes a Database model to Prisma schema format +func (w *Writer) WriteDatabase(db *models.Database) error { + content := w.databaseToPrisma(db) + + if w.options.OutputPath != "" { + return os.WriteFile(w.options.OutputPath, []byte(content), 0644) + } + + fmt.Print(content) + return nil +} + +// WriteSchema writes a Schema model to Prisma schema format +func (w *Writer) WriteSchema(schema *models.Schema) error { + // Create temporary database for schema + db := models.InitDatabase("database") + db.Schemas = []*models.Schema{schema} + + return w.WriteDatabase(db) +} + +// WriteTable writes a Table model to Prisma schema format +func (w *Writer) WriteTable(table *models.Table) error { + // Create temporary schema and database for table + schema := models.InitSchema(table.Schema) + schema.Tables = []*models.Table{table} + + return w.WriteSchema(schema) +} + +// databaseToPrisma converts a Database to Prisma schema format string +func (w *Writer) databaseToPrisma(db *models.Database) string { + var sb strings.Builder + + // Write datasource block + sb.WriteString(w.generateDatasource(db)) + sb.WriteString("\n") + + // Write generator block + sb.WriteString(w.generateGenerator()) + sb.WriteString("\n") + + // Process all schemas (typically just one in Prisma) + for _, schema := range db.Schemas { + // Write enums + if len(schema.Enums) > 0 { + for _, enum := range schema.Enums { + sb.WriteString(w.enumToPrisma(enum)) + sb.WriteString("\n") + } + } + + // Identify join tables for implicit M2M + joinTables := w.identifyJoinTables(schema) + + // Write models (excluding join tables) + for _, table := range schema.Tables { + if joinTables[table.Name] { + continue // Skip join tables + } + sb.WriteString(w.tableToPrisma(table, schema, joinTables)) + sb.WriteString("\n") + } + } + + return sb.String() +} + +// generateDatasource generates the datasource block +func (w *Writer) generateDatasource(db *models.Database) string { + provider := "postgresql" + + // Map database type to Prisma provider + switch db.DatabaseType { + case models.PostgresqlDatabaseType: + provider = "postgresql" + case models.MSSQLDatabaseType: + provider = "sqlserver" + case models.SqlLiteDatabaseType: + provider = "sqlite" + case "mysql": + provider = "mysql" + } + + return fmt.Sprintf(`datasource db { + provider = "%s" + url = env("DATABASE_URL") +} +`, provider) +} + +// generateGenerator generates the generator block +func (w *Writer) generateGenerator() string { + return `generator client { + provider = "prisma-client-js" +} +` +} + +// enumToPrisma converts an Enum to Prisma enum block +func (w *Writer) enumToPrisma(enum *models.Enum) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("enum %s {\n", enum.Name)) + for _, value := range enum.Values { + sb.WriteString(fmt.Sprintf(" %s\n", value)) + } + sb.WriteString("}\n") + + return sb.String() +} + +// identifyJoinTables identifies tables that are join tables for M2M relations +func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool { + joinTables := make(map[string]bool) + + for _, table := range schema.Tables { + // Check if this is a join table: + // 1. Starts with _ (Prisma convention) + // 2. Has exactly 2 FK constraints + // 3. Has composite PK with those 2 columns + // 4. Has no other columns except the FK columns + + if !strings.HasPrefix(table.Name, "_") { + continue + } + + fks := table.GetForeignKeys() + if len(fks) != 2 { + continue + } + + // Check if columns are only the FK columns + if len(table.Columns) != 2 { + continue + } + + // Check if both FK columns are part of PK + pkCols := 0 + for _, col := range table.Columns { + if col.IsPrimaryKey { + pkCols++ + } + } + + if pkCols == 2 { + joinTables[table.Name] = true + } + } + + return joinTables +} + +// tableToPrisma converts a Table to Prisma model block +func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("model %s {\n", table.Name)) + + // Collect columns to write + columns := make([]*models.Column, 0, len(table.Columns)) + for _, col := range table.Columns { + columns = append(columns, col) + } + + // Sort columns for consistent output + sort.Slice(columns, func(i, j int) bool { + return columns[i].Name < columns[j].Name + }) + + // Write scalar fields + for _, col := range columns { + // Skip if this column is part of a relation that will be output as array field + if w.isRelationColumn(col, table) { + // We'll output this with the relation field + continue + } + + sb.WriteString(w.columnToField(col, table, schema)) + } + + // Write relation fields + sb.WriteString(w.generateRelationFields(table, schema, joinTables)) + + // Write block attributes (@@id, @@unique, @@index) + sb.WriteString(w.generateBlockAttributes(table)) + + sb.WriteString("}\n") + + return sb.String() +} + +// columnToField converts a Column to a Prisma field definition +func (w *Writer) columnToField(col *models.Column, table *models.Table, schema *models.Schema) string { + var sb strings.Builder + + // Field name + sb.WriteString(fmt.Sprintf(" %s", col.Name)) + + // Field type + prismaType := w.sqlTypeToPrisma(col.Type, schema) + sb.WriteString(fmt.Sprintf(" %s", prismaType)) + + // Optional modifier + if !col.NotNull && !col.IsPrimaryKey { + sb.WriteString("?") + } + + // Field attributes + attributes := w.generateFieldAttributes(col, table) + if attributes != "" { + sb.WriteString(" ") + sb.WriteString(attributes) + } + + sb.WriteString("\n") + + return sb.String() +} + +// sqlTypeToPrisma converts SQL types to Prisma types +func (w *Writer) sqlTypeToPrisma(sqlType string, schema *models.Schema) string { + // Check if it's an enum + for _, enum := range schema.Enums { + if strings.EqualFold(sqlType, enum.Name) { + return enum.Name + } + } + + // Standard type mapping + typeMap := map[string]string{ + "text": "String", + "varchar": "String", + "character varying": "String", + "char": "String", + "boolean": "Boolean", + "bool": "Boolean", + "integer": "Int", + "int": "Int", + "int4": "Int", + "bigint": "BigInt", + "int8": "BigInt", + "double precision": "Float", + "float": "Float", + "float8": "Float", + "decimal": "Decimal", + "numeric": "Decimal", + "timestamp": "DateTime", + "timestamptz": "DateTime", + "date": "DateTime", + "jsonb": "Json", + "json": "Json", + "bytea": "Bytes", + } + + for sqlPattern, prismaType := range typeMap { + if strings.Contains(strings.ToLower(sqlType), sqlPattern) { + return prismaType + } + } + + // Default to String for unknown types + return "String" +} + +// generateFieldAttributes generates field attributes like @id, @unique, @default +func (w *Writer) generateFieldAttributes(col *models.Column, table *models.Table) string { + attrs := make([]string, 0) + + // @id + if col.IsPrimaryKey { + // Check if this is part of a composite key + pkCount := 0 + for _, c := range table.Columns { + if c.IsPrimaryKey { + pkCount++ + } + } + if pkCount == 1 { + attrs = append(attrs, "@id") + } + } + + // @unique + if w.hasUniqueConstraint(col.Name, table) { + attrs = append(attrs, "@unique") + } + + // @default + if col.AutoIncrement { + attrs = append(attrs, "@default(autoincrement())") + } else if col.Default != nil { + defaultAttr := w.formatDefaultValue(col.Default) + if defaultAttr != "" { + attrs = append(attrs, fmt.Sprintf("@default(%s)", defaultAttr)) + } + } + + // @updatedAt (check comment) + if strings.Contains(col.Comment, "@updatedAt") { + attrs = append(attrs, "@updatedAt") + } + + return strings.Join(attrs, " ") +} + +// formatDefaultValue formats a default value for Prisma +func (w *Writer) formatDefaultValue(defaultValue any) string { + switch v := defaultValue.(type) { + case string: + if v == "now()" { + return "now()" + } else if v == "gen_random_uuid()" { + return "uuid()" + } else if strings.Contains(strings.ToLower(v), "uuid") { + return "uuid()" + } else { + // String literal + return fmt.Sprintf(`"%s"`, v) + } + case bool: + if v { + return "true" + } + return "false" + case int, int64, int32: + return fmt.Sprintf("%v", v) + default: + return fmt.Sprintf("%v", v) + } +} + +// hasUniqueConstraint checks if a column has a unique constraint +func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool { + for _, constraint := range table.Constraints { + if constraint.Type == models.UniqueConstraint && + len(constraint.Columns) == 1 && + constraint.Columns[0] == colName { + return true + } + } + return false +} + +// isRelationColumn checks if a column is a FK column +func (w *Writer) isRelationColumn(col *models.Column, table *models.Table) bool { + for _, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint { + for _, fkCol := range constraint.Columns { + if fkCol == col.Name { + return true + } + } + } + } + return false +} + +// generateRelationFields generates relation fields and their FK columns +func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { + var sb strings.Builder + + // Get all FK constraints + fks := table.GetForeignKeys() + + for _, fk := range fks { + // Generate the FK scalar field + for _, fkCol := range fk.Columns { + if col, exists := table.Columns[fkCol]; exists { + sb.WriteString(w.columnToField(col, table, schema)) + } + } + + // Generate the relation field + relationType := fk.ReferencedTable + isOptional := false + + // Check if FK column is nullable + for _, fkCol := range fk.Columns { + if col, exists := table.Columns[fkCol]; exists { + if !col.NotNull { + isOptional = true + } + } + } + + relationName := relationType + if strings.HasSuffix(strings.ToLower(relationName), "s") { + relationName = relationName[:len(relationName)-1] + } + + sb.WriteString(fmt.Sprintf(" %s %s", strings.ToLower(relationName), relationType)) + + if isOptional { + sb.WriteString("?") + } + + // @relation attribute + relationAttr := w.generateRelationAttribute(fk) + if relationAttr != "" { + sb.WriteString(" ") + sb.WriteString(relationAttr) + } + + sb.WriteString("\n") + } + + // Generate inverse relations (arrays) for tables that reference this one + sb.WriteString(w.generateInverseRelations(table, schema, joinTables)) + + return sb.String() +} + +// generateRelationAttribute generates the @relation(...) attribute +func (w *Writer) generateRelationAttribute(fk *models.Constraint) string { + parts := make([]string, 0) + + // fields + fieldsStr := strings.Join(fk.Columns, ", ") + parts = append(parts, fmt.Sprintf("fields: [%s]", fieldsStr)) + + // references + referencesStr := strings.Join(fk.ReferencedColumns, ", ") + parts = append(parts, fmt.Sprintf("references: [%s]", referencesStr)) + + // onDelete + if fk.OnDelete != "" { + parts = append(parts, fmt.Sprintf("onDelete: %s", fk.OnDelete)) + } + + // onUpdate + if fk.OnUpdate != "" { + parts = append(parts, fmt.Sprintf("onUpdate: %s", fk.OnUpdate)) + } + + return fmt.Sprintf("@relation(%s)", strings.Join(parts, ", ")) +} + +// generateInverseRelations generates array fields for reverse relationships +func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { + var sb strings.Builder + + // Find all tables that have FKs pointing to this table + for _, otherTable := range schema.Tables { + if otherTable.Name == table.Name { + continue + } + + // Check if this is a join table + if joinTables[otherTable.Name] { + // Handle implicit M2M + if w.isJoinTableFor(otherTable, table.Name) { + // Find the other side of the M2M + for _, fk := range otherTable.GetForeignKeys() { + if fk.ReferencedTable != table.Name { + // This is the other side + otherSide := fk.ReferencedTable + sb.WriteString(fmt.Sprintf(" %ss %s[]\n", + strings.ToLower(otherSide), otherSide)) + break + } + } + } + continue + } + + // Regular one-to-many inverse relation + for _, fk := range otherTable.GetForeignKeys() { + if fk.ReferencedTable == table.Name { + // This table is referenced by otherTable + pluralName := otherTable.Name + if !strings.HasSuffix(pluralName, "s") { + pluralName += "s" + } + + sb.WriteString(fmt.Sprintf(" %s %s[]\n", + strings.ToLower(pluralName), otherTable.Name)) + } + } + } + + return sb.String() +} + +// isJoinTableFor checks if a table is a join table involving the specified model +func (w *Writer) isJoinTableFor(joinTable *models.Table, modelName string) bool { + for _, fk := range joinTable.GetForeignKeys() { + if fk.ReferencedTable == modelName { + return true + } + } + return false +} + +// generateBlockAttributes generates block-level attributes like @@id, @@unique, @@index +func (w *Writer) generateBlockAttributes(table *models.Table) string { + var sb strings.Builder + + // @@id for composite primary key + pkCols := make([]string, 0) + for _, col := range table.Columns { + if col.IsPrimaryKey { + pkCols = append(pkCols, col.Name) + } + } + + if len(pkCols) > 1 { + sort.Strings(pkCols) + sb.WriteString(fmt.Sprintf(" @@id([%s])\n", strings.Join(pkCols, ", "))) + } + + // @@unique for multi-column unique constraints + for _, constraint := range table.Constraints { + if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 { + sb.WriteString(fmt.Sprintf(" @@unique([%s])\n", strings.Join(constraint.Columns, ", "))) + } + } + + // @@index for indexes + for _, index := range table.Indexes { + if !index.Unique { // Unique indexes are handled by @@unique + sb.WriteString(fmt.Sprintf(" @@index([%s])\n", strings.Join(index.Columns, ", "))) + } + } + + return sb.String() +} diff --git a/pkg/writers/typeorm/writer.go b/pkg/writers/typeorm/writer.go new file mode 100644 index 0000000..759f796 --- /dev/null +++ b/pkg/writers/typeorm/writer.go @@ -0,0 +1,631 @@ +package typeorm + +import ( + "fmt" + "os" + "sort" + "strings" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +// Writer implements the writers.Writer interface for TypeORM entity format +type Writer struct { + options *writers.WriterOptions +} + +// NewWriter creates a new TypeORM writer with the given options +func NewWriter(options *writers.WriterOptions) *Writer { + return &Writer{ + options: options, + } +} + +// WriteDatabase writes a Database model to TypeORM entity format +func (w *Writer) WriteDatabase(db *models.Database) error { + content := w.databaseToTypeORM(db) + + if w.options.OutputPath != "" { + return os.WriteFile(w.options.OutputPath, []byte(content), 0644) + } + + fmt.Print(content) + return nil +} + +// WriteSchema writes a Schema model to TypeORM entity format +func (w *Writer) WriteSchema(schema *models.Schema) error { + db := models.InitDatabase("database") + db.Schemas = []*models.Schema{schema} + + return w.WriteDatabase(db) +} + +// WriteTable writes a Table model to TypeORM entity format +func (w *Writer) WriteTable(table *models.Table) error { + schema := models.InitSchema(table.Schema) + schema.Tables = []*models.Table{table} + + return w.WriteSchema(schema) +} + +// databaseToTypeORM converts a Database to TypeORM entity format string +func (w *Writer) databaseToTypeORM(db *models.Database) string { + var sb strings.Builder + + // Generate imports + sb.WriteString(w.generateImports(db)) + sb.WriteString("\n") + + // Process all schemas + for _, schema := range db.Schemas { + // Identify join tables + joinTables := w.identifyJoinTables(schema) + + // Write entities (excluding join tables) + for _, table := range schema.Tables { + if joinTables[table.Name] { + continue + } + sb.WriteString(w.tableToEntity(table, schema, joinTables)) + sb.WriteString("\n") + } + + // Write view entities + for _, view := range schema.Views { + sb.WriteString(w.viewToEntity(view)) + sb.WriteString("\n") + } + } + + return sb.String() +} + +// generateImports generates the TypeORM import statement +func (w *Writer) generateImports(db *models.Database) string { + imports := make([]string, 0) + + // Always include basic decorators + imports = append(imports, "Entity", "PrimaryGeneratedColumn", "Column") + + // Check if we need relation decorators + needsManyToOne := false + needsOneToMany := false + needsManyToMany := false + needsJoinTable := false + needsCreateDate := false + needsUpdateDate := false + needsViewEntity := false + + for _, schema := range db.Schemas { + // Check for views + if len(schema.Views) > 0 { + needsViewEntity = true + } + + for _, table := range schema.Tables { + // Check for timestamp columns + for _, col := range table.Columns { + if col.Default == "now()" { + needsCreateDate = true + } + if strings.Contains(col.Comment, "auto-update") { + needsUpdateDate = true + } + } + + // Check for relations + for _, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint { + needsManyToOne = true + } + } + } + } + + // OneToMany is the inverse of ManyToOne + if needsManyToOne { + needsOneToMany = true + } + + // Check for M2M (join tables indicate M2M relations) + joinTables := make(map[string]bool) + for _, schema := range db.Schemas { + jt := w.identifyJoinTables(schema) + for name := range jt { + joinTables[name] = true + needsManyToMany = true + needsJoinTable = true + } + } + + if needsManyToOne { + imports = append(imports, "ManyToOne") + } + if needsOneToMany { + imports = append(imports, "OneToMany") + } + if needsManyToMany { + imports = append(imports, "ManyToMany") + } + if needsJoinTable { + imports = append(imports, "JoinTable") + } + if needsCreateDate { + imports = append(imports, "CreateDateColumn") + } + if needsUpdateDate { + imports = append(imports, "UpdateDateColumn") + } + if needsViewEntity { + imports = append(imports, "ViewEntity") + } + + return fmt.Sprintf("import { %s } from 'typeorm';\n", strings.Join(imports, ", ")) +} + +// identifyJoinTables identifies tables that are join tables for M2M relations +func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool { + joinTables := make(map[string]bool) + + for _, table := range schema.Tables { + // Check if this is a join table: + // 1. Has exactly 2 FK constraints + // 2. Has composite PK with those 2 columns + // 3. Has no other columns except the FK columns + + fks := table.GetForeignKeys() + if len(fks) != 2 { + continue + } + + // Check if columns are only the FK columns + if len(table.Columns) != 2 { + continue + } + + // Check if both FK columns are part of PK + pkCols := 0 + for _, col := range table.Columns { + if col.IsPrimaryKey { + pkCols++ + } + } + + if pkCols == 2 { + joinTables[table.Name] = true + } + } + + return joinTables +} + +// tableToEntity converts a Table to a TypeORM entity class +func (w *Writer) tableToEntity(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { + var sb strings.Builder + + // Generate @Entity decorator with options + entityOptions := w.buildEntityOptions(table) + sb.WriteString(fmt.Sprintf("@Entity({\n%s\n})\n", entityOptions)) + + // Get class name (from metadata if different from table name) + className := table.Name + if table.Metadata != nil { + if classNameVal, ok := table.Metadata["class_name"]; ok { + if classNameStr, ok := classNameVal.(string); ok { + className = classNameStr + } + } + } + + sb.WriteString(fmt.Sprintf("export class %s {\n", className)) + + // Collect and sort columns + columns := make([]*models.Column, 0, len(table.Columns)) + for _, col := range table.Columns { + // Skip FK columns (they'll be represented as relations) + if w.isForeignKeyColumn(col, table) { + continue + } + columns = append(columns, col) + } + + sort.Slice(columns, func(i, j int) bool { + // Put PK first, then alphabetical + if columns[i].IsPrimaryKey && !columns[j].IsPrimaryKey { + return true + } + if !columns[i].IsPrimaryKey && columns[j].IsPrimaryKey { + return false + } + return columns[i].Name < columns[j].Name + }) + + // Write scalar fields + for _, col := range columns { + sb.WriteString(w.columnToField(col, table)) + sb.WriteString("\n") + } + + // Write relation fields + sb.WriteString(w.generateRelationFields(table, schema, joinTables)) + + sb.WriteString("}\n") + + return sb.String() +} + +// viewToEntity converts a View to a TypeORM @ViewEntity class +func (w *Writer) viewToEntity(view *models.View) string { + var sb strings.Builder + + // Generate @ViewEntity decorator with expression + sb.WriteString("@ViewEntity({\n") + if view.Definition != "" { + // Format the SQL expression with proper indentation + sb.WriteString(" expression: `\n") + sb.WriteString(" ") + sb.WriteString(view.Definition) + sb.WriteString("\n `,\n") + } + sb.WriteString("})\n") + + // Generate class + sb.WriteString(fmt.Sprintf("export class %s {\n", view.Name)) + + // Generate field definitions (without decorators for view fields) + columns := make([]*models.Column, 0, len(view.Columns)) + for _, col := range view.Columns { + columns = append(columns, col) + } + sort.Slice(columns, func(i, j int) bool { + return columns[i].Name < columns[j].Name + }) + + for _, col := range columns { + tsType := w.sqlTypeToTypeScript(col.Type) + sb.WriteString(fmt.Sprintf(" %s: %s;\n", col.Name, tsType)) + } + + sb.WriteString("}\n") + + return sb.String() +} + +// columnToField converts a Column to a TypeORM field +func (w *Writer) columnToField(col *models.Column, table *models.Table) string { + var sb strings.Builder + + // Generate decorator + if col.IsPrimaryKey { + if col.AutoIncrement { + sb.WriteString(" @PrimaryGeneratedColumn('increment')\n") + } else if col.Type == "uuid" || strings.Contains(fmt.Sprint(col.Default), "uuid") { + sb.WriteString(" @PrimaryGeneratedColumn('uuid')\n") + } else { + sb.WriteString(" @PrimaryGeneratedColumn()\n") + } + } else if col.Default == "now()" { + sb.WriteString(" @CreateDateColumn()\n") + } else if strings.Contains(col.Comment, "auto-update") { + sb.WriteString(" @UpdateDateColumn()\n") + } else { + // Regular @Column decorator + options := w.buildColumnOptions(col, table) + if options != "" { + sb.WriteString(fmt.Sprintf(" @Column({ %s })\n", options)) + } else { + sb.WriteString(" @Column()\n") + } + } + + // Generate field declaration + tsType := w.sqlTypeToTypeScript(col.Type) + nullable := "" + if !col.NotNull { + nullable = " | null" + } + + sb.WriteString(fmt.Sprintf(" %s: %s%s;", col.Name, tsType, nullable)) + + return sb.String() +} + +// buildColumnOptions builds the options object for @Column decorator +func (w *Writer) buildColumnOptions(col *models.Column, table *models.Table) string { + options := make([]string, 0) + + // Type (if not default) + if w.needsExplicitType(col.Type) { + options = append(options, fmt.Sprintf("type: '%s'", col.Type)) + } + + // Nullable + if !col.NotNull { + options = append(options, "nullable: true") + } + + // Unique + if w.hasUniqueConstraint(col.Name, table) { + options = append(options, "unique: true") + } + + // Default + if col.Default != nil && col.Default != "now()" { + defaultStr := fmt.Sprint(col.Default) + if defaultStr != "" { + options = append(options, fmt.Sprintf("default: '%s'", defaultStr)) + } + } + + return strings.Join(options, ", ") +} + +// needsExplicitType checks if a SQL type needs explicit type declaration +func (w *Writer) needsExplicitType(sqlType string) bool { + // Types that don't map cleanly to TypeScript types need explicit declaration + explicitTypes := []string{"text", "uuid", "jsonb", "bigint"} + for _, t := range explicitTypes { + if strings.Contains(sqlType, t) { + return true + } + } + return false +} + +// hasUniqueConstraint checks if a column has a unique constraint +func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool { + for _, constraint := range table.Constraints { + if constraint.Type == models.UniqueConstraint && + len(constraint.Columns) == 1 && + constraint.Columns[0] == colName { + return true + } + } + return false +} + +// sqlTypeToTypeScript converts SQL types to TypeScript types +func (w *Writer) sqlTypeToTypeScript(sqlType string) string { + typeMap := map[string]string{ + "text": "string", + "varchar": "string", + "character varying": "string", + "char": "string", + "uuid": "string", + "boolean": "boolean", + "bool": "boolean", + "integer": "number", + "int": "number", + "bigint": "number", + "double precision": "number", + "float": "number", + "decimal": "number", + "numeric": "number", + "timestamp": "Date", + "timestamptz": "Date", + "date": "Date", + "jsonb": "any", + "json": "any", + } + + for sqlPattern, tsType := range typeMap { + if strings.Contains(strings.ToLower(sqlType), sqlPattern) { + return tsType + } + } + + return "any" +} + +// isForeignKeyColumn checks if a column is a FK column +func (w *Writer) isForeignKeyColumn(col *models.Column, table *models.Table) bool { + for _, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint { + for _, fkCol := range constraint.Columns { + if fkCol == col.Name { + return true + } + } + } + } + return false +} + +// generateRelationFields generates relation fields for a table +func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { + var sb strings.Builder + + // Get all FK constraints + fks := table.GetForeignKeys() + + // Generate @ManyToOne fields + for _, fk := range fks { + relatedTable := fk.ReferencedTable + fieldName := strings.ToLower(relatedTable) + + // Determine if nullable + isNullable := false + for _, fkCol := range fk.Columns { + if col, exists := table.Columns[fkCol]; exists { + if !col.NotNull { + isNullable = true + } + } + } + + nullable := "" + if isNullable { + nullable = " | null" + } + + // Find inverse field name if possible + inverseField := w.findInverseFieldName(table.Name, relatedTable, schema) + + if inverseField != "" { + sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, %s => %s.%s)\n", + relatedTable, strings.ToLower(relatedTable), strings.ToLower(relatedTable), inverseField)) + } else { + if isNullable { + sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, { nullable: true })\n", relatedTable)) + } else { + sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s)\n", relatedTable)) + } + } + + sb.WriteString(fmt.Sprintf(" %s: %s%s;\n", fieldName, relatedTable, nullable)) + sb.WriteString("\n") + } + + // Generate @OneToMany fields (inverse of FKs pointing to this table) + w.generateInverseRelations(table, schema, joinTables, &sb) + + // Generate @ManyToMany fields + w.generateManyToManyRelations(table, schema, joinTables, &sb) + + return sb.String() +} + +// findInverseFieldName finds the inverse field name for a relation +func (w *Writer) findInverseFieldName(fromTable, toTable string, schema *models.Schema) string { + // Look for tables that have FKs pointing back to fromTable + for _, table := range schema.Tables { + if table.Name != toTable { + continue + } + + for _, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint && constraint.ReferencedTable == fromTable { + // Found an inverse relation + // Use pluralized form of fromTable + return w.pluralize(strings.ToLower(fromTable)) + } + } + } + + return "" +} + +// generateInverseRelations generates @OneToMany fields +func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool, sb *strings.Builder) { + for _, otherTable := range schema.Tables { + if otherTable.Name == table.Name || joinTables[otherTable.Name] { + continue + } + + for _, fk := range otherTable.GetForeignKeys() { + if fk.ReferencedTable == table.Name { + // This table is referenced by otherTable + fieldName := w.pluralize(strings.ToLower(otherTable.Name)) + inverseName := strings.ToLower(table.Name) + + sb.WriteString(fmt.Sprintf(" @OneToMany(() => %s, %s => %s.%s)\n", + otherTable.Name, strings.ToLower(otherTable.Name), strings.ToLower(otherTable.Name), inverseName)) + sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable.Name)) + sb.WriteString("\n") + } + } + } +} + +// generateManyToManyRelations generates @ManyToMany fields +func (w *Writer) generateManyToManyRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool, sb *strings.Builder) { + for joinTableName := range joinTables { + joinTable := w.findTable(joinTableName, schema) + if joinTable == nil { + continue + } + + fks := joinTable.GetForeignKeys() + if len(fks) != 2 { + continue + } + + // Check if this table is part of the M2M relation + var thisTableFK *models.Constraint + var otherTableFK *models.Constraint + + for i, fk := range fks { + if fk.ReferencedTable == table.Name { + thisTableFK = fk + if i == 0 { + otherTableFK = fks[1] + } else { + otherTableFK = fks[0] + } + } + } + + if thisTableFK == nil { + continue + } + + // Determine which side owns the relation (has @JoinTable) + // We'll make the first entity alphabetically the owner + isOwner := table.Name < otherTableFK.ReferencedTable + + otherTable := otherTableFK.ReferencedTable + fieldName := w.pluralize(strings.ToLower(otherTable)) + inverseName := w.pluralize(strings.ToLower(table.Name)) + + if isOwner { + sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n", + otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName)) + sb.WriteString(" @JoinTable()\n") + } else { + sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n", + otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName)) + } + + sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable)) + sb.WriteString("\n") + } +} + +// findTable finds a table by name in a schema +func (w *Writer) findTable(name string, schema *models.Schema) *models.Table { + for _, table := range schema.Tables { + if table.Name == name { + return table + } + } + return nil +} + +// buildEntityOptions builds the options object for @Entity decorator +func (w *Writer) buildEntityOptions(table *models.Table) string { + options := make([]string, 0) + + // Always include table name + options = append(options, fmt.Sprintf(" name: \"%s\"", table.Name)) + + // Always include schema + options = append(options, fmt.Sprintf(" schema: \"%s\"", table.Schema)) + + // Database name from metadata + if table.Metadata != nil { + if database, ok := table.Metadata["database"]; ok { + if databaseStr, ok := database.(string); ok { + options = append(options, fmt.Sprintf(" database: \"%s\"", databaseStr)) + } + } + + // Engine from metadata + if engine, ok := table.Metadata["engine"]; ok { + if engineStr, ok := engine.(string); ok { + options = append(options, fmt.Sprintf(" engine: \"%s\"", engineStr)) + } + } + } + + return strings.Join(options, ",\n") +} + +// pluralize adds 's' to make a word plural (simple version) +func (w *Writer) pluralize(word string) string { + if strings.HasSuffix(word, "s") { + return word + } + return word + "s" +} diff --git a/tests/assets/prisma/example.prisma b/tests/assets/prisma/example.prisma new file mode 100644 index 0000000..134649e --- /dev/null +++ b/tests/assets/prisma/example.prisma @@ -0,0 +1,46 @@ +datasource db { + provider = "postgresql" +} + +generator client { + provider = "prisma-client" + output = "./generated" +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + role Role @default(USER) + posts Post[] + profile Profile? +} + +model Profile { + id Int @id @default(autoincrement()) + bio String + user User @relation(fields: [userId], references: [id]) + userId Int @unique +} + +model Post { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + title String + published Boolean @default(false) + author User @relation(fields: [authorId], references: [id]) + authorId Int + categories Category[] +} + +model Category { + id Int @id @default(autoincrement()) + name String + posts Post[] +} + +enum Role { + USER + ADMIN +} \ No newline at end of file diff --git a/tests/assets/typeorm/example.ts b/tests/assets/typeorm/example.ts new file mode 100644 index 0000000..8bfd59a --- /dev/null +++ b/tests/assets/typeorm/example.ts @@ -0,0 +1,115 @@ +//@ts-nocheck +import { Entity, PrimaryGeneratedColumn, Column, ManyToOne, OneToMany, ManyToMany, JoinTable, CreateDateColumn, UpdateDateColumn } from 'typeorm'; + +@Entity() +export class User { + @PrimaryGeneratedColumn('uuid') + id: string; + + @Column({ unique: true }) + email: string; + + @Column() + name: string; + + @CreateDateColumn() + createdAt: Date; + + @UpdateDateColumn() + updatedAt: Date; + + @OneToMany(() => Project, project => project.owner) + ownedProjects: Project[]; + + @ManyToMany(() => Project, project => project.members) + @JoinTable() + projects: Project[]; +} + +@Entity() +export class Project { + @PrimaryGeneratedColumn('uuid') + id: string; + + @Column() + title: string; + + @Column({ nullable: true }) + description: string; + + @Column({ default: 'active' }) + status: string; + + @ManyToOne(() => User, user => user.ownedProjects) + owner: User; + + @ManyToMany(() => User, user => user.projects) + members: User[]; + + @OneToMany(() => Task, task => task.project) + tasks: Task[]; + + @CreateDateColumn() + createdAt: Date; +} + +@Entity() +export class Task { + @PrimaryGeneratedColumn('uuid') + id: string; + + @Column() + title: string; + + @Column({ type: 'text', nullable: true }) + description: string; + + @Column({ default: 'todo' }) + status: string; + + @Column({ nullable: true }) + dueDate: Date; + + @ManyToOne(() => Project, project => project.tasks) + project: Project; + + @ManyToOne(() => User, { nullable: true }) + assignee: User; + + @OneToMany(() => Comment, comment => comment.task) + comments: Comment[]; +} + +@Entity() +export class Comment { + @PrimaryGeneratedColumn('uuid') + id: string; + + @Column('text') + content: string; + + @ManyToOne(() => Task, task => task.comments) + task: Task; + + @ManyToOne(() => User) + author: User; + + @CreateDateColumn() + createdAt: Date; +} + +@Entity() +export class Tag { + @PrimaryGeneratedColumn('uuid') + id: string; + + @Column({ unique: true }) + name: string; + + @Column() + color: string; + + @ManyToMany(() => Task) + @JoinTable() + tasks: Task[]; +} \ No newline at end of file