816 lines
23 KiB
Go
816 lines
23 KiB
Go
package prisma
|
|
|
|
import (
|
|
"bufio"
|
|
"fmt"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
|
)
|
|
|
|
// Reader implements the readers.Reader interface for Prisma schema format
|
|
type Reader struct {
|
|
options *readers.ReaderOptions
|
|
}
|
|
|
|
// NewReader creates a new Prisma reader with the given options
|
|
func NewReader(options *readers.ReaderOptions) *Reader {
|
|
return &Reader{
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
// ReadDatabase reads and parses Prisma schema input, returning a Database model
|
|
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
|
if r.options.FilePath == "" {
|
|
return nil, fmt.Errorf("file path is required for Prisma reader")
|
|
}
|
|
|
|
content, err := os.ReadFile(r.options.FilePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read file: %w", err)
|
|
}
|
|
|
|
return r.parsePrisma(string(content))
|
|
}
|
|
|
|
// ReadSchema reads and parses Prisma schema input, returning a Schema model
|
|
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
|
db, err := r.ReadDatabase()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(db.Schemas) == 0 {
|
|
return nil, fmt.Errorf("no schemas found in Prisma schema")
|
|
}
|
|
|
|
// Return the first schema
|
|
return db.Schemas[0], nil
|
|
}
|
|
|
|
// ReadTable reads and parses Prisma schema input, returning a Table model
|
|
func (r *Reader) ReadTable() (*models.Table, error) {
|
|
schema, err := r.ReadSchema()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(schema.Tables) == 0 {
|
|
return nil, fmt.Errorf("no tables found in Prisma schema")
|
|
}
|
|
|
|
// Return the first table
|
|
return schema.Tables[0], nil
|
|
}
|
|
|
|
// parsePrisma parses Prisma schema content and returns a Database model
|
|
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
|
db := models.InitDatabase("database")
|
|
|
|
if r.options.Metadata != nil {
|
|
if name, ok := r.options.Metadata["name"].(string); ok {
|
|
db.Name = name
|
|
}
|
|
}
|
|
|
|
// Default schema for Prisma (doesn't have explicit schema concept in most cases)
|
|
schema := models.InitSchema("public")
|
|
schema.Enums = make([]*models.Enum, 0)
|
|
|
|
scanner := bufio.NewScanner(strings.NewReader(content))
|
|
|
|
// State tracking
|
|
var currentBlock string // "datasource", "generator", "model", "enum"
|
|
var currentTable *models.Table
|
|
var currentEnum *models.Enum
|
|
var blockContent []string
|
|
|
|
// Regex patterns
|
|
datasourceRegex := regexp.MustCompile(`^datasource\s+\w+\s*{`)
|
|
generatorRegex := regexp.MustCompile(`^generator\s+\w+\s*{`)
|
|
modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`)
|
|
enumRegex := regexp.MustCompile(`^enum\s+(\w+)\s*{`)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
trimmed := strings.TrimSpace(line)
|
|
|
|
// Skip empty lines and comments
|
|
if trimmed == "" || strings.HasPrefix(trimmed, "//") {
|
|
continue
|
|
}
|
|
|
|
// Check for block start
|
|
if matches := datasourceRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
currentBlock = "datasource"
|
|
blockContent = []string{}
|
|
continue
|
|
}
|
|
|
|
if matches := generatorRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
currentBlock = "generator"
|
|
blockContent = []string{}
|
|
continue
|
|
}
|
|
|
|
if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
currentBlock = "model"
|
|
tableName := matches[1]
|
|
currentTable = models.InitTable(tableName, "public")
|
|
blockContent = []string{}
|
|
continue
|
|
}
|
|
|
|
if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
currentBlock = "enum"
|
|
enumName := matches[1]
|
|
currentEnum = &models.Enum{
|
|
Name: enumName,
|
|
Schema: "public",
|
|
Values: make([]string, 0),
|
|
}
|
|
blockContent = []string{}
|
|
continue
|
|
}
|
|
|
|
// Check for block end
|
|
if trimmed == "}" {
|
|
switch currentBlock {
|
|
case "datasource":
|
|
r.parseDatasource(blockContent, db)
|
|
case "generator":
|
|
// We don't need to do anything with generator blocks
|
|
case "model":
|
|
if currentTable != nil {
|
|
r.parseModelFields(blockContent, currentTable)
|
|
schema.Tables = append(schema.Tables, currentTable)
|
|
currentTable = nil
|
|
}
|
|
case "enum":
|
|
if currentEnum != nil {
|
|
schema.Enums = append(schema.Enums, currentEnum)
|
|
currentEnum = nil
|
|
}
|
|
}
|
|
currentBlock = ""
|
|
blockContent = []string{}
|
|
continue
|
|
}
|
|
|
|
// Accumulate block content
|
|
if currentBlock != "" {
|
|
if currentBlock == "enum" && currentEnum != nil {
|
|
// For enums, just add the trimmed value
|
|
if trimmed != "" {
|
|
currentEnum.Values = append(currentEnum.Values, trimmed)
|
|
}
|
|
} else {
|
|
blockContent = append(blockContent, line)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Second pass: resolve relationships
|
|
r.resolveRelationships(schema)
|
|
|
|
db.Schemas = append(db.Schemas, schema)
|
|
return db, nil
|
|
}
|
|
|
|
// parseDatasource extracts database type from datasource block
|
|
func (r *Reader) parseDatasource(lines []string, db *models.Database) {
|
|
providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)
|
|
|
|
for _, line := range lines {
|
|
if matches := providerRegex.FindStringSubmatch(line); matches != nil {
|
|
provider := matches[1]
|
|
switch provider {
|
|
case "postgresql", "postgres":
|
|
db.DatabaseType = models.PostgresqlDatabaseType
|
|
case "mysql":
|
|
db.DatabaseType = "mysql"
|
|
case "sqlite":
|
|
db.DatabaseType = models.SqlLiteDatabaseType
|
|
case "sqlserver":
|
|
db.DatabaseType = models.MSSQLDatabaseType
|
|
default:
|
|
db.DatabaseType = models.PostgresqlDatabaseType
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseModelFields parses model field definitions
|
|
func (r *Reader) parseModelFields(lines []string, table *models.Table) {
|
|
fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`)
|
|
blockAttrRegex := regexp.MustCompile(`^@@(\w+)\((.*?)\)`)
|
|
|
|
for _, line := range lines {
|
|
trimmed := strings.TrimSpace(line)
|
|
|
|
// Skip empty lines and comments
|
|
if trimmed == "" || strings.HasPrefix(trimmed, "//") {
|
|
continue
|
|
}
|
|
|
|
// Check for block attributes (@@id, @@unique, @@index)
|
|
if matches := blockAttrRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
attrName := matches[1]
|
|
attrContent := matches[2]
|
|
r.parseBlockAttribute(attrName, attrContent, table)
|
|
continue
|
|
}
|
|
|
|
// Parse field definition
|
|
if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
fieldName := matches[1]
|
|
fieldType := matches[2]
|
|
modifier := matches[3] // ? or []
|
|
attributes := matches[4] // @... part
|
|
|
|
column := r.parseField(fieldName, fieldType, modifier, attributes, table)
|
|
if column != nil {
|
|
table.Columns[column.Name] = column
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseField parses a single field definition
|
|
func (r *Reader) parseField(name, fieldType, modifier, attributes string, table *models.Table) *models.Column {
|
|
// Check if this is a relation field (array or references another model)
|
|
if modifier == "[]" {
|
|
// Array field - this is a relation field, not a column
|
|
// We'll handle this in relationship resolution
|
|
return nil
|
|
}
|
|
|
|
// Check if this is a non-primitive type (relation field)
|
|
// Note: We need to allow enum types through as they're like primitives
|
|
if !r.isPrimitiveType(fieldType) && !r.isEnumType(fieldType, table) {
|
|
// This is a relation field (e.g., user User), not a scalar column
|
|
// Only process this if it has @relation attribute (which means it's the owning side with FK)
|
|
// Otherwise skip it as it's just the inverse relation field
|
|
if attributes == "" || !strings.Contains(attributes, "@relation") {
|
|
return nil
|
|
}
|
|
// If it has @relation, we still don't create a column for it
|
|
// The actual FK column will be in the fields: [...] part of @relation
|
|
return nil
|
|
}
|
|
|
|
column := models.InitColumn(name, table.Name, table.Schema)
|
|
|
|
// Map Prisma type to SQL type
|
|
column.Type = r.prismaTypeToSQL(fieldType)
|
|
|
|
// Handle modifiers
|
|
if modifier == "?" {
|
|
column.NotNull = false
|
|
} else {
|
|
// Default: required fields are NOT NULL
|
|
column.NotNull = true
|
|
}
|
|
|
|
// Parse field attributes
|
|
if attributes != "" {
|
|
r.parseFieldAttributes(attributes, column, table)
|
|
}
|
|
|
|
return column
|
|
}
|
|
|
|
// prismaTypeToSQL converts Prisma types to SQL types
|
|
func (r *Reader) prismaTypeToSQL(prismaType string) string {
|
|
typeMap := map[string]string{
|
|
"String": "text",
|
|
"Boolean": "boolean",
|
|
"Int": "integer",
|
|
"BigInt": "bigint",
|
|
"Float": "double precision",
|
|
"Decimal": "decimal",
|
|
"DateTime": "timestamp",
|
|
"Json": "jsonb",
|
|
"Bytes": "bytea",
|
|
}
|
|
|
|
if sqlType, ok := typeMap[prismaType]; ok {
|
|
return sqlType
|
|
}
|
|
|
|
// If not a built-in type, it might be an enum or model reference
|
|
// For enums, we'll use the enum name directly
|
|
return prismaType
|
|
}
|
|
|
|
// parseFieldAttributes parses field attributes like @id, @unique, @default
|
|
func (r *Reader) parseFieldAttributes(attributes string, column *models.Column, table *models.Table) {
|
|
// @id attribute
|
|
if strings.Contains(attributes, "@id") {
|
|
column.IsPrimaryKey = true
|
|
column.NotNull = true
|
|
}
|
|
|
|
// @unique attribute
|
|
if regexp.MustCompile(`@unique\b`).MatchString(attributes) {
|
|
uniqueConstraint := models.InitConstraint(
|
|
fmt.Sprintf("uq_%s", column.Name),
|
|
models.UniqueConstraint,
|
|
)
|
|
uniqueConstraint.Schema = table.Schema
|
|
uniqueConstraint.Table = table.Name
|
|
uniqueConstraint.Columns = []string{column.Name}
|
|
table.Constraints[uniqueConstraint.Name] = uniqueConstraint
|
|
}
|
|
|
|
// @default attribute - extract value with balanced parentheses
|
|
if strings.Contains(attributes, "@default(") {
|
|
defaultValue := r.extractDefaultValue(attributes)
|
|
if defaultValue != "" {
|
|
r.parseDefaultValue(defaultValue, column)
|
|
}
|
|
}
|
|
|
|
// @updatedAt attribute - store in comment for now
|
|
if strings.Contains(attributes, "@updatedAt") {
|
|
if column.Comment != "" {
|
|
column.Comment += "; @updatedAt"
|
|
} else {
|
|
column.Comment = "@updatedAt"
|
|
}
|
|
}
|
|
|
|
// @relation attribute - we'll handle this in relationship resolution
|
|
// For now, just note that this field is part of a relation
|
|
}
|
|
|
|
// extractDefaultValue extracts the default value from @default(...) handling nested parentheses
|
|
func (r *Reader) extractDefaultValue(attributes string) string {
|
|
idx := strings.Index(attributes, "@default(")
|
|
if idx == -1 {
|
|
return ""
|
|
}
|
|
|
|
start := idx + len("@default(")
|
|
depth := 1
|
|
i := start
|
|
|
|
for i < len(attributes) && depth > 0 {
|
|
switch attributes[i] {
|
|
case '(':
|
|
depth++
|
|
case ')':
|
|
depth--
|
|
}
|
|
i++
|
|
}
|
|
|
|
if depth == 0 {
|
|
return attributes[start : i-1]
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// parseDefaultValue parses Prisma default value expressions
|
|
func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) {
|
|
defaultExpr = strings.TrimSpace(defaultExpr)
|
|
|
|
switch defaultExpr {
|
|
case "autoincrement()":
|
|
column.AutoIncrement = true
|
|
case "now()":
|
|
column.Default = "now()"
|
|
case "uuid()":
|
|
column.Default = "gen_random_uuid()"
|
|
case "cuid()":
|
|
// CUID is Prisma-specific, store in comment
|
|
if column.Comment != "" {
|
|
column.Comment += "; default(cuid())"
|
|
} else {
|
|
column.Comment = "default(cuid())"
|
|
}
|
|
case "true":
|
|
column.Default = true
|
|
case "false":
|
|
column.Default = false
|
|
default:
|
|
// Check if it's a string literal
|
|
if strings.HasPrefix(defaultExpr, "\"") && strings.HasSuffix(defaultExpr, "\"") {
|
|
column.Default = defaultExpr[1 : len(defaultExpr)-1]
|
|
} else if strings.HasPrefix(defaultExpr, "'") && strings.HasSuffix(defaultExpr, "'") {
|
|
column.Default = defaultExpr[1 : len(defaultExpr)-1]
|
|
} else {
|
|
// Try to parse as number or enum value
|
|
column.Default = defaultExpr
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseBlockAttribute parses block-level attributes like @@id, @@unique, @@index
|
|
func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Table) {
|
|
// Extract column list from brackets [col1, col2]
|
|
colListRegex := regexp.MustCompile(`\[(.*?)\]`)
|
|
matches := colListRegex.FindStringSubmatch(content)
|
|
if matches == nil {
|
|
return
|
|
}
|
|
|
|
columnList := strings.Split(matches[1], ",")
|
|
columns := make([]string, 0)
|
|
for _, col := range columnList {
|
|
columns = append(columns, strings.TrimSpace(col))
|
|
}
|
|
|
|
switch attrName {
|
|
case "id":
|
|
// Composite primary key
|
|
for _, colName := range columns {
|
|
if col, exists := table.Columns[colName]; exists {
|
|
col.IsPrimaryKey = true
|
|
col.NotNull = true
|
|
}
|
|
}
|
|
// Also create a PK constraint
|
|
pkConstraint := models.InitConstraint(
|
|
fmt.Sprintf("pk_%s", table.Name),
|
|
models.PrimaryKeyConstraint,
|
|
)
|
|
pkConstraint.Schema = table.Schema
|
|
pkConstraint.Table = table.Name
|
|
pkConstraint.Columns = columns
|
|
table.Constraints[pkConstraint.Name] = pkConstraint
|
|
|
|
case "unique":
|
|
// Multi-column unique constraint
|
|
uniqueConstraint := models.InitConstraint(
|
|
fmt.Sprintf("uq_%s_%s", table.Name, strings.Join(columns, "_")),
|
|
models.UniqueConstraint,
|
|
)
|
|
uniqueConstraint.Schema = table.Schema
|
|
uniqueConstraint.Table = table.Name
|
|
uniqueConstraint.Columns = columns
|
|
table.Constraints[uniqueConstraint.Name] = uniqueConstraint
|
|
|
|
case "index":
|
|
// Index
|
|
index := models.InitIndex(
|
|
fmt.Sprintf("idx_%s_%s", table.Name, strings.Join(columns, "_")),
|
|
table.Name,
|
|
table.Schema,
|
|
)
|
|
index.Columns = columns
|
|
table.Indexes[index.Name] = index
|
|
}
|
|
}
|
|
|
|
// relationField records one model field that may take part in a relation,
// collected by extractRelationFields for the second (relationship) pass.
// It covers list fields and any field whose type is not a Prisma primitive,
// so enum-typed fields are included as well as model references.
type relationField struct {
	tableName    string // model that declares the field
	fieldName    string // the field's own name
	relatedModel string // declared type name, e.g. "User" in `author User`
	isArray      bool   // true for list fields such as `posts Post[]`
	relationAttr string // full trailing attribute text ("@..."), "" if none
}
|
|
|
|
// resolveRelationships performs a second pass to resolve @relation attributes
|
|
func (r *Reader) resolveRelationships(schema *models.Schema) {
|
|
// Build a map of table names for quick lookup
|
|
tableMap := make(map[string]*models.Table)
|
|
for _, table := range schema.Tables {
|
|
tableMap[table.Name] = table
|
|
}
|
|
|
|
// First, we need to re-parse to find relation fields
|
|
// We'll re-read the file to extract relation information
|
|
if r.options.FilePath == "" {
|
|
return
|
|
}
|
|
|
|
content, err := os.ReadFile(r.options.FilePath)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
relations := r.extractRelationFields(string(content))
|
|
|
|
// Process explicit @relation attributes to create FK constraints
|
|
for _, rel := range relations {
|
|
if rel.relationAttr != "" {
|
|
r.createConstraintFromRelation(rel, tableMap, schema)
|
|
}
|
|
}
|
|
|
|
// Detect implicit many-to-many relationships
|
|
r.detectImplicitManyToMany(relations, tableMap, schema)
|
|
}
|
|
|
|
// extractRelationFields extracts relation field information from the schema
|
|
func (r *Reader) extractRelationFields(content string) []relationField {
|
|
relations := make([]relationField, 0)
|
|
scanner := bufio.NewScanner(strings.NewReader(content))
|
|
|
|
modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`)
|
|
fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`)
|
|
|
|
var currentModel string
|
|
inModel := false
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
trimmed := strings.TrimSpace(line)
|
|
|
|
if trimmed == "" || strings.HasPrefix(trimmed, "//") {
|
|
continue
|
|
}
|
|
|
|
if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
currentModel = matches[1]
|
|
inModel = true
|
|
continue
|
|
}
|
|
|
|
if trimmed == "}" {
|
|
inModel = false
|
|
currentModel = ""
|
|
continue
|
|
}
|
|
|
|
if inModel && currentModel != "" {
|
|
if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil {
|
|
fieldName := matches[1]
|
|
fieldType := matches[2]
|
|
modifier := matches[3]
|
|
attributes := matches[4]
|
|
|
|
// Check if this is a relation field (references another model or is an array)
|
|
isPotentialRelation := modifier == "[]" || !r.isPrimitiveType(fieldType)
|
|
|
|
if isPotentialRelation {
|
|
rel := relationField{
|
|
tableName: currentModel,
|
|
fieldName: fieldName,
|
|
relatedModel: fieldType,
|
|
isArray: modifier == "[]",
|
|
relationAttr: attributes,
|
|
}
|
|
relations = append(relations, rel)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return relations
|
|
}
|
|
|
|
// isPrimitiveType checks if a type is a Prisma primitive type
|
|
func (r *Reader) isPrimitiveType(typeName string) bool {
|
|
primitives := []string{"String", "Boolean", "Int", "BigInt", "Float", "Decimal", "DateTime", "Json", "Bytes"}
|
|
for _, p := range primitives {
|
|
if typeName == p {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isEnumType checks if a type name might be an enum
|
|
// Note: We can't definitively check against schema.Enums at parse time
|
|
// because enums might be defined after the model, so we just check
|
|
// if it starts with uppercase (Prisma convention for enums)
|
|
func (r *Reader) isEnumType(typeName string, table *models.Table) bool {
|
|
// Simple heuristic: enum types start with uppercase letter
|
|
// and are not known model names (though we can't check that yet)
|
|
if len(typeName) > 0 && typeName[0] >= 'A' && typeName[0] <= 'Z' {
|
|
// Additional check: primitive types are already handled above
|
|
// So if it's uppercase and not primitive, it's likely an enum or model
|
|
// We'll assume it's an enum if it's a single word
|
|
return !strings.Contains(typeName, "_")
|
|
}
|
|
return false
|
|
}
|
|
|
|
// createConstraintFromRelation creates a FK constraint from a @relation attribute
|
|
func (r *Reader) createConstraintFromRelation(rel relationField, tableMap map[string]*models.Table, schema *models.Schema) {
|
|
// Skip array fields (they are the inverse side of the relation)
|
|
if rel.isArray {
|
|
return
|
|
}
|
|
|
|
if rel.relationAttr == "" {
|
|
return
|
|
}
|
|
|
|
// Parse @relation attribute
|
|
relationRegex := regexp.MustCompile(`@relation\((.*?)\)`)
|
|
matches := relationRegex.FindStringSubmatch(rel.relationAttr)
|
|
if matches == nil {
|
|
return
|
|
}
|
|
|
|
relationContent := matches[1]
|
|
|
|
// Extract fields and references
|
|
fieldsRegex := regexp.MustCompile(`fields:\s*\[(.*?)\]`)
|
|
referencesRegex := regexp.MustCompile(`references:\s*\[(.*?)\]`)
|
|
nameRegex := regexp.MustCompile(`name:\s*"([^"]+)"`)
|
|
onDeleteRegex := regexp.MustCompile(`onDelete:\s*(\w+)`)
|
|
onUpdateRegex := regexp.MustCompile(`onUpdate:\s*(\w+)`)
|
|
|
|
fieldsMatch := fieldsRegex.FindStringSubmatch(relationContent)
|
|
referencesMatch := referencesRegex.FindStringSubmatch(relationContent)
|
|
|
|
if fieldsMatch == nil || referencesMatch == nil {
|
|
return
|
|
}
|
|
|
|
// Parse field and reference column lists
|
|
fieldCols := r.parseColumnList(fieldsMatch[1])
|
|
refCols := r.parseColumnList(referencesMatch[1])
|
|
|
|
if len(fieldCols) == 0 || len(refCols) == 0 {
|
|
return
|
|
}
|
|
|
|
// Create FK constraint
|
|
constraintName := fmt.Sprintf("fk_%s_%s", rel.tableName, fieldCols[0])
|
|
|
|
// Check for custom name
|
|
if nameMatch := nameRegex.FindStringSubmatch(relationContent); nameMatch != nil {
|
|
constraintName = nameMatch[1]
|
|
}
|
|
|
|
constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
|
|
constraint.Schema = "public"
|
|
constraint.Table = rel.tableName
|
|
constraint.Columns = fieldCols
|
|
constraint.ReferencedSchema = "public"
|
|
constraint.ReferencedTable = rel.relatedModel
|
|
constraint.ReferencedColumns = refCols
|
|
|
|
// Parse referential actions
|
|
if onDeleteMatch := onDeleteRegex.FindStringSubmatch(relationContent); onDeleteMatch != nil {
|
|
constraint.OnDelete = onDeleteMatch[1]
|
|
}
|
|
|
|
if onUpdateMatch := onUpdateRegex.FindStringSubmatch(relationContent); onUpdateMatch != nil {
|
|
constraint.OnUpdate = onUpdateMatch[1]
|
|
}
|
|
|
|
// Add constraint to table
|
|
if table, exists := tableMap[rel.tableName]; exists {
|
|
table.Constraints[constraint.Name] = constraint
|
|
}
|
|
}
|
|
|
|
// parseColumnList parses a comma-separated list of column names
|
|
func (r *Reader) parseColumnList(list string) []string {
|
|
parts := strings.Split(list, ",")
|
|
result := make([]string, 0)
|
|
for _, part := range parts {
|
|
trimmed := strings.TrimSpace(part)
|
|
if trimmed != "" {
|
|
result = append(result, trimmed)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// detectImplicitManyToMany detects implicit M2M relationships and creates join tables
|
|
func (r *Reader) detectImplicitManyToMany(relations []relationField, tableMap map[string]*models.Table, schema *models.Schema) {
|
|
// Group relations by model pairs
|
|
type modelPair struct {
|
|
model1 string
|
|
model2 string
|
|
}
|
|
|
|
pairMap := make(map[modelPair][]relationField)
|
|
|
|
for _, rel := range relations {
|
|
if !rel.isArray || rel.relationAttr != "" {
|
|
// Skip non-array fields and explicit relations
|
|
continue
|
|
}
|
|
|
|
// Create a normalized pair (alphabetically sorted to avoid duplicates)
|
|
pair := modelPair{}
|
|
if rel.tableName < rel.relatedModel {
|
|
pair.model1 = rel.tableName
|
|
pair.model2 = rel.relatedModel
|
|
} else {
|
|
pair.model1 = rel.relatedModel
|
|
pair.model2 = rel.tableName
|
|
}
|
|
|
|
pairMap[pair] = append(pairMap[pair], rel)
|
|
}
|
|
|
|
// Check for pairs with arrays on both sides (implicit M2M)
|
|
for pair, rels := range pairMap {
|
|
if len(rels) >= 2 {
|
|
// This is an implicit many-to-many relationship
|
|
r.createImplicitJoinTable(pair.model1, pair.model2, tableMap, schema)
|
|
}
|
|
}
|
|
}
|
|
|
|
// createImplicitJoinTable creates a virtual join table for implicit M2M relations
|
|
func (r *Reader) createImplicitJoinTable(model1, model2 string, tableMap map[string]*models.Table, schema *models.Schema) {
|
|
// Prisma naming convention: _Model1ToModel2 (alphabetically sorted)
|
|
joinTableName := fmt.Sprintf("_%sTo%s", model1, model2)
|
|
|
|
// Check if join table already exists
|
|
if _, exists := tableMap[joinTableName]; exists {
|
|
return
|
|
}
|
|
|
|
// Create join table
|
|
joinTable := models.InitTable(joinTableName, "public")
|
|
|
|
// Get primary keys from both tables
|
|
pk1 := r.getPrimaryKeyColumn(tableMap[model1])
|
|
pk2 := r.getPrimaryKeyColumn(tableMap[model2])
|
|
|
|
if pk1 == nil || pk2 == nil {
|
|
return // Can't create join table without PKs
|
|
}
|
|
|
|
// Create FK columns in join table
|
|
fkCol1Name := fmt.Sprintf("%sId", model1)
|
|
fkCol1 := models.InitColumn(fkCol1Name, joinTableName, "public")
|
|
fkCol1.Type = pk1.Type
|
|
fkCol1.NotNull = true
|
|
joinTable.Columns[fkCol1Name] = fkCol1
|
|
|
|
fkCol2Name := fmt.Sprintf("%sId", model2)
|
|
fkCol2 := models.InitColumn(fkCol2Name, joinTableName, "public")
|
|
fkCol2.Type = pk2.Type
|
|
fkCol2.NotNull = true
|
|
joinTable.Columns[fkCol2Name] = fkCol2
|
|
|
|
// Create composite primary key
|
|
pkConstraint := models.InitConstraint(
|
|
fmt.Sprintf("pk_%s", joinTableName),
|
|
models.PrimaryKeyConstraint,
|
|
)
|
|
pkConstraint.Schema = "public"
|
|
pkConstraint.Table = joinTableName
|
|
pkConstraint.Columns = []string{fkCol1Name, fkCol2Name}
|
|
joinTable.Constraints[pkConstraint.Name] = pkConstraint
|
|
|
|
// Mark columns as PK
|
|
fkCol1.IsPrimaryKey = true
|
|
fkCol2.IsPrimaryKey = true
|
|
|
|
// Create FK constraints
|
|
fk1 := models.InitConstraint(
|
|
fmt.Sprintf("fk_%s_%s", joinTableName, model1),
|
|
models.ForeignKeyConstraint,
|
|
)
|
|
fk1.Schema = "public"
|
|
fk1.Table = joinTableName
|
|
fk1.Columns = []string{fkCol1Name}
|
|
fk1.ReferencedSchema = "public"
|
|
fk1.ReferencedTable = model1
|
|
fk1.ReferencedColumns = []string{pk1.Name}
|
|
fk1.OnDelete = "Cascade"
|
|
joinTable.Constraints[fk1.Name] = fk1
|
|
|
|
fk2 := models.InitConstraint(
|
|
fmt.Sprintf("fk_%s_%s", joinTableName, model2),
|
|
models.ForeignKeyConstraint,
|
|
)
|
|
fk2.Schema = "public"
|
|
fk2.Table = joinTableName
|
|
fk2.Columns = []string{fkCol2Name}
|
|
fk2.ReferencedSchema = "public"
|
|
fk2.ReferencedTable = model2
|
|
fk2.ReferencedColumns = []string{pk2.Name}
|
|
fk2.OnDelete = "Cascade"
|
|
joinTable.Constraints[fk2.Name] = fk2
|
|
|
|
// Add join table to schema
|
|
schema.Tables = append(schema.Tables, joinTable)
|
|
tableMap[joinTableName] = joinTable
|
|
}
|
|
|
|
// getPrimaryKeyColumn returns the primary key column of a table
|
|
func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column {
|
|
if table == nil {
|
|
return nil
|
|
}
|
|
|
|
for _, col := range table.Columns {
|
|
if col.IsPrimaryKey {
|
|
return col
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|