package prisma

import (
	"bufio"
	"fmt"
	"os"
	"regexp"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/readers"
)

// Regular expressions used by the Prisma parser, compiled once at package
// initialisation rather than on every call.
var (
	// Block headers, matched against a trimmed, comment-stripped line.
	reDatasourceBlock = regexp.MustCompile(`^datasource\s+\w+\s*{`)
	reGeneratorBlock  = regexp.MustCompile(`^generator\s+\w+\s*{`)
	reModelBlock      = regexp.MustCompile(`^model\s+(\w+)\s*{`)
	reEnumBlock       = regexp.MustCompile(`^enum\s+(\w+)\s*{`)

	// provider = "postgresql" inside a datasource block.
	reProvider = regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)

	// Field definition: name, type (a plain identifier or Unsupported("...")),
	// optional modifier (? or []), then the remaining @attributes.
	reField = regexp.MustCompile(`^(\w+)\s+(Unsupported\("[^"]*"\)|\w+)(\?|\[\])?\s*(@.+)?`)

	// Block attribute such as @@id([a, b]). The argument capture is greedy so
	// nested parentheses like b(sort: Desc) stay inside it.
	reBlockAttr  = regexp.MustCompile(`^@@(\w+)\((.*)\)`)
	reColumnList = regexp.MustCompile(`\[(.*?)\]`)

	// Field-level attributes.
	reIDAttr        = regexp.MustCompile(`@id\b`)
	reUniqueAttr    = regexp.MustCompile(`@unique\b`)
	reUpdatedAtAttr = regexp.MustCompile(`@updatedAt\b`)

	// Arguments of @relation(...) and of constraint/index attributes.
	reRelationFields     = regexp.MustCompile(`\bfields:\s*\[(.*?)\]`)
	reRelationReferences = regexp.MustCompile(`\breferences:\s*\[(.*?)\]`)
	reMapArgument        = regexp.MustCompile(`\bmap:\s*"([^"]*)"`)
	reOnDelete           = regexp.MustCompile(`\bonDelete:\s*(\w+)`)
	reOnUpdate           = regexp.MustCompile(`\bonUpdate:\s*(\w+)`)
)

// prismaScalarSQLTypes maps Prisma scalar types to the SQL types emitted for them.
// Its key set is also the set of types isPrimitiveType recognises.
var prismaScalarSQLTypes = map[string]string{
	"String":   "text",
	"Boolean":  "boolean",
	"Int":      "integer",
	"BigInt":   "bigint",
	"Float":    "double precision",
	"Decimal":  "decimal",
	"DateTime": "timestamp",
	"Json":     "jsonb",
	"Bytes":    "bytea",
}

// Reader implements the readers.Reader interface for Prisma schema format
type Reader struct {
	options *readers.ReaderOptions

	// knownEnums is the set of enum names declared in the schema currently
	// being parsed. parsePrisma fills it before any model field is classified,
	// so that enum-typed fields are kept as columns while fields typed with a
	// model name are treated as relations. When nil, isEnumType falls back to a
	// naming heuristic. Because parsing writes this field, a single Reader must
	// not be used to parse concurrently.
	knownEnums map[string]bool
}

// NewReader creates a new Prisma reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
	return &Reader{
		options: options,
	}
}

// ReadDatabase reads the file at options.FilePath, parses it as a Prisma
// schema and returns a Database model. It returns an error when no file path
// is configured, the file cannot be read, or scanning the content fails.
func (r *Reader) ReadDatabase() (*models.Database, error) {
	if r.options == nil || r.options.FilePath == "" {
		return nil, fmt.Errorf("file path is required for Prisma reader")
	}

	content, err := os.ReadFile(r.options.FilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}

	return r.parsePrisma(string(content))
}

// ReadSchema reads and parses Prisma schema input, returning the first (and
// only) Schema model of the parsed database.
func (r *Reader) ReadSchema() (*models.Schema, error) {
	db, err := r.ReadDatabase()
	if err != nil {
		return nil, err
	}

	if len(db.Schemas) == 0 {
		return nil, fmt.Errorf("no schemas found in Prisma schema")
	}

	return db.Schemas[0], nil
}

// ReadTable reads and parses Prisma schema input, returning the first Table
// model of the parsed schema.
func (r *Reader) ReadTable() (*models.Table, error) {
	schema, err := r.ReadSchema()
	if err != nil {
		return nil, err
	}

	if len(schema.Tables) == 0 {
		return nil, fmt.Errorf("no tables found in Prisma schema")
	}

	return schema.Tables[0], nil
}

// parsePrisma parses Prisma schema content and returns a Database model with a
// single "public" schema. Datasource blocks set the database type, model
// blocks become tables, enum blocks become enums, and relationships are then
// resolved from the same content. An error is returned only if scanning the
// content fails.
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
	db := models.InitDatabase("database")
	if r.options != nil && r.options.Metadata != nil {
		if name, ok := r.options.Metadata["name"].(string); ok {
			db.Name = name
		}
	}

	// Prisma has no explicit schema concept in most cases, so everything lands in "public".
	schema := models.InitSchema("public")
	schema.Enums = make([]*models.Enum, 0)

	// Enums may be declared after the models that use them, so collect their
	// names up front; field classification depends on them.
	r.knownEnums = collectEnumNames(content)

	var (
		currentBlock string // "datasource", "generator", "model", "enum", or "" outside a block
		currentTable *models.Table
		currentEnum  *models.Enum
		blockContent []string
	)

	scanner := newLineScanner(content)
	for scanner.Scan() {
		line := scanner.Text()
		trimmed := stripTrailingComment(strings.TrimSpace(line))
		if trimmed == "" {
			continue // blank or comment-only line
		}

		// Block headers.
		if reDatasourceBlock.MatchString(trimmed) {
			currentBlock = "datasource"
			blockContent = []string{}
			continue
		}
		if reGeneratorBlock.MatchString(trimmed) {
			currentBlock = "generator"
			blockContent = []string{}
			continue
		}
		if m := reModelBlock.FindStringSubmatch(trimmed); m != nil {
			currentBlock = "model"
			currentTable = models.InitTable(m[1], "public")
			blockContent = []string{}
			continue
		}
		if m := reEnumBlock.FindStringSubmatch(trimmed); m != nil {
			currentBlock = "enum"
			currentEnum = models.InitEnum(m[1], "public")
			blockContent = []string{}
			continue
		}

		// Block end.
		if trimmed == "}" {
			switch currentBlock {
			case "datasource":
				r.parseDatasource(blockContent, db)
			case "generator":
				// Generator blocks carry no database structure.
			case "model":
				if currentTable != nil {
					r.parseModelFields(blockContent, currentTable)
					schema.Tables = append(schema.Tables, currentTable)
					currentTable = nil
				}
			case "enum":
				if currentEnum != nil {
					schema.Enums = append(schema.Enums, currentEnum)
					currentEnum = nil
				}
			}
			currentBlock = ""
			blockContent = []string{}
			continue
		}

		// Block body.
		switch {
		case currentBlock == "":
			// Content outside any recognised block is ignored.
		case currentBlock == "enum" && currentEnum != nil:
			if value := enumValueName(trimmed); value != "" {
				currentEnum.Values = append(currentEnum.Values, value)
			}
		default:
			blockContent = append(blockContent, line)
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("scanning Prisma schema: %w", err)
	}

	// Second pass: resolve relationships from the content already in memory.
	r.resolveRelationshipsFromContent(content, schema)

	db.Schemas = append(db.Schemas, schema)
	return db, nil
}

// parseDatasource sets db.DatabaseType from the first `provider = ...` line of
// a datasource block. Unrecognised providers default to PostgreSQL.
func (r *Reader) parseDatasource(lines []string, db *models.Database) {
	for _, line := range lines {
		m := reProvider.FindStringSubmatch(line)
		if m == nil {
			continue
		}
		switch m[1] {
		case "postgresql", "postgres":
			db.DatabaseType = models.PostgresqlDatabaseType
		case "mysql":
			db.DatabaseType = "mysql"
		case "sqlite":
			db.DatabaseType = models.SqlLiteDatabaseType
		case "sqlserver":
			db.DatabaseType = models.MSSQLDatabaseType
		default:
			db.DatabaseType = models.PostgresqlDatabaseType
		}
		return // only the first provider line counts
	}
}

// parseModelFields parses the body lines of a model block into table columns,
// constraints and indexes. Block attributes (@@id, @@unique, @@index) are
// applied after all fields, so @@id marks its columns even when it precedes them.
func (r *Reader) parseModelFields(lines []string, table *models.Table) {
	type blockAttr struct {
		name    string
		content string
	}
	var deferred []blockAttr

	for _, line := range lines {
		trimmed := stripTrailingComment(strings.TrimSpace(line))
		if trimmed == "" {
			continue
		}

		if m := reBlockAttr.FindStringSubmatch(trimmed); m != nil {
			deferred = append(deferred, blockAttr{name: m[1], content: m[2]})
			continue
		}

		if m := reField.FindStringSubmatch(trimmed); m != nil {
			// m[1] name, m[2] type, m[3] modifier (? or []), m[4] @attributes
			if column := r.parseField(m[1], m[2], m[3], m[4], table); column != nil {
				table.Columns[column.Name] = column
			}
		}
	}

	for _, attr := range deferred {
		r.parseBlockAttribute(attr.name, attr.content, table)
	}
}

// parseField turns one field definition into a column, or returns nil when the
// field is not a column (list fields and fields typed with another model).
func (r *Reader) parseField(name, fieldType, modifier, attributes string, table *models.Table) *models.Column {
	if modifier == "[]" {
		// List fields are the back-reference side of a relation. Scalar lists
		// such as String[] are currently skipped as well.
		return nil
	}

	sqlType, isUnsupported := unsupportedSQLType(fieldType)
	switch {
	case isUnsupported:
		// Unsupported("...") carries the native database type verbatim.
	case r.isPrimitiveType(fieldType) || r.isEnumType(fieldType, table):
		sqlType = r.prismaTypeToSQL(fieldType)
	default:
		// A field typed with another model (e.g. `author User @relation(...)`)
		// is a relation, not a column; its FK columns are separate scalar
		// fields listed in @relation(fields: [...]).
		return nil
	}

	column := models.InitColumn(name, table.Name, table.Schema)
	column.Type = sqlType
	column.NotNull = modifier != "?" // fields without ? are required

	if attributes != "" {
		r.parseFieldAttributes(attributes, column, table)
	}
	return column
}

// prismaTypeToSQL converts a Prisma scalar type to its SQL type. Any other
// name (an enum) is returned unchanged.
func (r *Reader) prismaTypeToSQL(prismaType string) string {
	if sqlType, ok := prismaScalarSQLTypes[prismaType]; ok {
		return sqlType
	}
	return prismaType
}

// parseFieldAttributes applies field attributes (@id, @unique, @default,
// @updatedAt) to column, adding a unique constraint to table for @unique.
// @relation is handled later, during relationship resolution.
func (r *Reader) parseFieldAttributes(attributes string, column *models.Column, table *models.Table) {
	if reIDAttr.MatchString(attributes) {
		column.IsPrimaryKey = true
		column.NotNull = true
	}

	if reUniqueAttr.MatchString(attributes) {
		// Table-qualified so that equally named columns in different tables do
		// not produce clashing constraint names; @unique(map: "...") overrides.
		name := fmt.Sprintf("uq_%s_%s", table.Name, column.Name)
		if args, ok := extractAttributeArgs(attributes, "@unique"); ok {
			if custom := mapArgument(args); custom != "" {
				name = custom
			}
		}
		uniqueConstraint := models.InitConstraint(name, models.UniqueConstraint)
		uniqueConstraint.Schema = table.Schema
		uniqueConstraint.Table = table.Name
		uniqueConstraint.Columns = []string{column.Name}
		table.Constraints[uniqueConstraint.Name] = uniqueConstraint
	}

	if defaultValue := r.extractDefaultValue(attributes); defaultValue != "" {
		r.parseDefaultValue(defaultValue, column)
	}

	// @updatedAt has no column-level SQL equivalent here; record it in the comment.
	if reUpdatedAtAttr.MatchString(attributes) {
		appendColumnComment(column, "@updatedAt")
	}
}

// extractDefaultValue extracts the argument text of @default(...), handling
// nested parentheses and quoted strings. It returns "" when there is no
// @default or its parentheses are unbalanced.
func (r *Reader) extractDefaultValue(attributes string) string {
	value, _ := extractAttributeArgs(attributes, "@default")
	return value
}

// parseDefaultValue maps a Prisma default expression onto column: well-known
// functions are translated, boolean literals become bools, quoted literals are
// unquoted, and anything else (numbers, enum values, ...) is stored verbatim.
func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) {
	expr := strings.TrimSpace(defaultExpr)

	switch expr {
	case "autoincrement()":
		column.AutoIncrement = true
	case "now()":
		column.Default = "now()"
	case "uuid()":
		column.Default = "gen_random_uuid()"
	case "cuid()":
		// CUID is Prisma-specific; keep it as a note rather than a SQL default.
		appendColumnComment(column, "default(cuid())")
	case "true":
		column.Default = true
	case "false":
		column.Default = false
	default:
		// The length guard keeps a lone quote character from slicing out of range.
		quoted := len(expr) >= 2 &&
			((expr[0] == '"' && expr[len(expr)-1] == '"') ||
				(expr[0] == '\'' && expr[len(expr)-1] == '\''))
		if quoted {
			column.Default = expr[1 : len(expr)-1]
		} else {
			column.Default = expr
		}
	}
}

// parseBlockAttribute applies a block-level attribute (@@id, @@unique,
// @@index) to table. content is the text inside the attribute's parentheses;
// its first [...] list names the columns, and map: "..." overrides the
// generated constraint/index name.
func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Table) {
	m := reColumnList.FindStringSubmatch(content)
	if m == nil {
		return
	}
	columns := parseAttributeColumns(m[1])
	if len(columns) == 0 {
		return
	}

	customName := mapArgument(content)
	nameOr := func(generated string) string {
		if customName != "" {
			return customName
		}
		return generated
	}

	switch attrName {
	case "id":
		// Composite primary key: flag each column, then add the PK constraint.
		for _, colName := range columns {
			if col, exists := table.Columns[colName]; exists {
				col.IsPrimaryKey = true
				col.NotNull = true
			}
		}
		pkConstraint := models.InitConstraint(
			nameOr(fmt.Sprintf("pk_%s", table.Name)),
			models.PrimaryKeyConstraint,
		)
		pkConstraint.Schema = table.Schema
		pkConstraint.Table = table.Name
		pkConstraint.Columns = columns
		table.Constraints[pkConstraint.Name] = pkConstraint

	case "unique":
		uniqueConstraint := models.InitConstraint(
			nameOr(fmt.Sprintf("uq_%s_%s", table.Name, strings.Join(columns, "_"))),
			models.UniqueConstraint,
		)
		uniqueConstraint.Schema = table.Schema
		uniqueConstraint.Table = table.Name
		uniqueConstraint.Columns = columns
		table.Constraints[uniqueConstraint.Name] = uniqueConstraint

	case "index":
		index := models.InitIndex(
			nameOr(fmt.Sprintf("idx_%s_%s", table.Name, strings.Join(columns, "_"))),
			table.Name,
			table.Schema,
		)
		index.Columns = columns
		table.Indexes[index.Name] = index
	}
}

// relationField stores information about a relation field for second-pass processing
type relationField struct {
	tableName    string
	fieldName    string
	relatedModel string
	isArray      bool
	relationAttr string
}

// resolveRelationships re-reads options.FilePath and resolves relationships
// from it. parsePrisma resolves from its in-memory content instead; this
// wrapper remains so that any other caller in the package keeps working.
// Read failures are ignored: relationships are best-effort enrichment here.
func (r *Reader) resolveRelationships(schema *models.Schema) {
	if r.options == nil || r.options.FilePath == "" {
		return
	}
	content, err := os.ReadFile(r.options.FilePath)
	if err != nil {
		return
	}
	r.resolveRelationshipsFromContent(string(content), schema)
}

// resolveRelationshipsFromContent creates FK constraints for explicit
// @relation(fields: [...], references: [...]) attributes and join tables for
// implicit many-to-many relations found in content.
func (r *Reader) resolveRelationshipsFromContent(content string, schema *models.Schema) {
	tableMap := make(map[string]*models.Table, len(schema.Tables))
	for _, table := range schema.Tables {
		tableMap[table.Name] = table
	}

	relations := r.extractRelationFields(content)
	for _, rel := range relations {
		if rel.relationAttr != "" {
			r.createConstraintFromRelation(rel, tableMap, schema)
		}
	}

	r.detectImplicitManyToMany(relations, tableMap, schema)
}

// extractRelationFields returns, for every model in content, the fields whose
// type is neither a Prisma scalar, a declared enum, nor Unsupported(...),
// i.e. the fields that reference another model.
func (r *Reader) extractRelationFields(content string) []relationField {
	enums := collectEnumNames(content)
	relations := make([]relationField, 0)

	var currentModel string // "" outside a model block
	scanner := newLineScanner(content)
	for scanner.Scan() {
		trimmed := stripTrailingComment(strings.TrimSpace(scanner.Text()))
		if trimmed == "" {
			continue
		}

		if m := reModelBlock.FindStringSubmatch(trimmed); m != nil {
			currentModel = m[1]
			continue
		}
		if trimmed == "}" {
			currentModel = ""
			continue
		}
		if currentModel == "" {
			continue
		}

		m := reField.FindStringSubmatch(trimmed)
		if m == nil {
			continue
		}
		fieldType := m[2]
		if _, isUnsupported := unsupportedSQLType(fieldType); isUnsupported ||
			r.isPrimitiveType(fieldType) || enums[fieldType] {
			continue
		}
		relations = append(relations, relationField{
			tableName:    currentModel,
			fieldName:    m[1],
			relatedModel: fieldType,
			isArray:      m[3] == "[]",
			relationAttr: m[4],
		})
	}
	// The scanner reads from a strings.Reader with a buffer sized to the
	// content, so it cannot report an error here.
	return relations
}

// isPrimitiveType reports whether typeName is a Prisma scalar type.
func (r *Reader) isPrimitiveType(typeName string) bool {
	_, ok := prismaScalarSQLTypes[typeName]
	return ok
}

// isEnumType reports whether typeName names an enum. While parsing, it checks
// the enums declared in the schema (r.knownEnums). Without that information
// it falls back to a heuristic: an upper-case first letter and no underscore.
// The heuristic cannot tell enums from model names.
func (r *Reader) isEnumType(typeName string, table *models.Table) bool {
	if r.knownEnums != nil {
		return r.knownEnums[typeName]
	}
	if typeName == "" || typeName[0] < 'A' || typeName[0] > 'Z' {
		return false
	}
	return !strings.Contains(typeName, "_")
}

// createConstraintFromRelation adds a FK constraint to rel's table for an
// owning-side @relation(fields: [...], references: [...]). The constraint is
// named by map: "..." when present, otherwise fk_<table>_<first field>.
// onDelete/onUpdate actions are copied verbatim.
func (r *Reader) createConstraintFromRelation(rel relationField, tableMap map[string]*models.Table, schema *models.Schema) {
	// List fields are the inverse side and never own the foreign key.
	if rel.isArray || rel.relationAttr == "" {
		return
	}

	args, ok := extractAttributeArgs(rel.relationAttr, "@relation")
	if !ok {
		return
	}

	fieldsMatch := reRelationFields.FindStringSubmatch(args)
	referencesMatch := reRelationReferences.FindStringSubmatch(args)
	if fieldsMatch == nil || referencesMatch == nil {
		return
	}
	fieldCols := r.parseColumnList(fieldsMatch[1])
	refCols := r.parseColumnList(referencesMatch[1])
	if len(fieldCols) == 0 || len(refCols) == 0 {
		return
	}

	table, exists := tableMap[rel.tableName]
	if !exists {
		return
	}

	// Prisma's `name:` argument is the relation name; `map:` is the database
	// constraint name.
	constraintName := fmt.Sprintf("fk_%s_%s", rel.tableName, fieldCols[0])
	if custom := mapArgument(args); custom != "" {
		constraintName = custom
	}

	constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
	constraint.Schema = "public"
	constraint.Table = rel.tableName
	constraint.Columns = fieldCols
	constraint.ReferencedSchema = "public"
	constraint.ReferencedTable = rel.relatedModel
	constraint.ReferencedColumns = refCols

	if m := reOnDelete.FindStringSubmatch(args); m != nil {
		constraint.OnDelete = m[1]
	}
	if m := reOnUpdate.FindStringSubmatch(args); m != nil {
		constraint.OnUpdate = m[1]
	}

	table.Constraints[constraint.Name] = constraint
}

// parseColumnList parses a comma-separated list of column names, dropping empty entries.
func (r *Reader) parseColumnList(list string) []string {
	parts := strings.Split(list, ",")
	result := make([]string, 0, len(parts))
	for _, part := range parts {
		if trimmed := strings.TrimSpace(part); trimmed != "" {
			result = append(result, trimmed)
		}
	}
	return result
}

// detectImplicitManyToMany creates join tables for implicit many-to-many
// relations: pairs of models that reference each other through list fields
// without an @relation attribute. A pair qualifies only when both models
// declare such a list field (for a self-relation, when the model declares two).
// Pairs are processed in first-seen order so the generated tables are deterministic.
func (r *Reader) detectImplicitManyToMany(relations []relationField, tableMap map[string]*models.Table, schema *models.Schema) {
	// modelPair is an alphabetically normalised pair of model names.
	type modelPair struct {
		model1 string
		model2 string
	}
	// sideCounts records how many qualifying list fields each model declares.
	type sideCounts struct {
		fromModel1 int
		fromModel2 int
	}

	counts := make(map[modelPair]*sideCounts)
	var order []modelPair
	for _, rel := range relations {
		if !rel.isArray || rel.relationAttr != "" {
			continue // skip non-list fields and explicit relations
		}

		pair := modelPair{model1: rel.tableName, model2: rel.relatedModel}
		if pair.model2 < pair.model1 {
			pair.model1, pair.model2 = pair.model2, pair.model1
		}

		c, seen := counts[pair]
		if !seen {
			c = &sideCounts{}
			counts[pair] = c
			order = append(order, pair)
		}
		if rel.tableName == pair.model1 {
			c.fromModel1++
		} else {
			c.fromModel2++
		}
	}

	for _, pair := range order {
		c := counts[pair]
		isManyToMany := c.fromModel1 > 0 && c.fromModel2 > 0
		if pair.model1 == pair.model2 {
			// Self-relation: both list fields live on the same model.
			isManyToMany = c.fromModel1 >= 2
		}
		if isManyToMany {
			r.createImplicitJoinTable(pair.model1, pair.model2, tableMap, schema)
		}
	}
}

// createImplicitJoinTable adds the join table _<model1>To<model2> with one FK
// column per side (typed like the referenced primary key), a composite primary
// key, and cascading FKs. It does nothing if the table already exists or
// either model has no primary key column.
func (r *Reader) createImplicitJoinTable(model1, model2 string, tableMap map[string]*models.Table, schema *models.Schema) {
	joinTableName := fmt.Sprintf("_%sTo%s", model1, model2)
	if _, exists := tableMap[joinTableName]; exists {
		return
	}

	pk1 := r.getPrimaryKeyColumn(tableMap[model1])
	pk2 := r.getPrimaryKeyColumn(tableMap[model2])
	if pk1 == nil || pk2 == nil {
		return // can't create a join table without PKs
	}

	joinTable := models.InitTable(joinTableName, "public")

	// NOTE(review): Prisma's own migrations name implicit join-table columns
	// "A" and "B"; this reader uses "<Model>Id". Confirm what consumers expect.
	addSide := func(model string, pk *models.Column) string {
		colName := fmt.Sprintf("%sId", model)

		col := models.InitColumn(colName, joinTableName, "public")
		col.Type = pk.Type
		col.NotNull = true
		col.IsPrimaryKey = true
		joinTable.Columns[colName] = col

		fk := models.InitConstraint(
			fmt.Sprintf("fk_%s_%s", joinTableName, model),
			models.ForeignKeyConstraint,
		)
		fk.Schema = "public"
		fk.Table = joinTableName
		fk.Columns = []string{colName}
		fk.ReferencedSchema = "public"
		fk.ReferencedTable = model
		fk.ReferencedColumns = []string{pk.Name}
		fk.OnDelete = "Cascade"
		joinTable.Constraints[fk.Name] = fk

		return colName
	}
	col1 := addSide(model1, pk1)
	col2 := addSide(model2, pk2)

	pkConstraint := models.InitConstraint(
		fmt.Sprintf("pk_%s", joinTableName),
		models.PrimaryKeyConstraint,
	)
	pkConstraint.Schema = "public"
	pkConstraint.Table = joinTableName
	pkConstraint.Columns = []string{col1, col2}
	joinTable.Constraints[pkConstraint.Name] = pkConstraint

	schema.Tables = append(schema.Tables, joinTable)
	tableMap[joinTableName] = joinTable
}

// getPrimaryKeyColumn returns a primary key column of table, or nil when table
// is nil or has none. With several PK columns, it returns the one with the
// smallest name so the result does not depend on map iteration order.
func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column {
	if table == nil {
		return nil
	}
	var pk *models.Column
	for _, col := range table.Columns {
		if col.IsPrimaryKey && (pk == nil || col.Name < pk.Name) {
			pk = col
		}
	}
	return pk
}

// newLineScanner returns a line scanner over content whose token limit is
// large enough for any line in content, so long lines never end a scan early.
func newLineScanner(content string) *bufio.Scanner {
	scanner := bufio.NewScanner(strings.NewReader(content))
	limit := bufio.MaxScanTokenSize
	if len(content)+1 > limit {
		limit = len(content) + 1
	}
	scanner.Buffer(make([]byte, 0, 4096), limit)
	return scanner
}

// stripTrailingComment removes a `//` comment that is outside a double-quoted
// string and trims surrounding whitespace. A comment-only line yields "".
func stripTrailingComment(s string) string {
	inString := false
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case inString && c == '\\':
			i++ // skip the escaped character
		case c == '"':
			inString = !inString
		case !inString && c == '/' && i+1 < len(s) && s[i+1] == '/':
			return strings.TrimSpace(s[:i])
		}
	}
	return strings.TrimSpace(s)
}

// collectEnumNames returns the set of enum names declared in content.
func collectEnumNames(content string) map[string]bool {
	names := make(map[string]bool)
	for _, line := range strings.Split(content, "\n") {
		if m := reEnumBlock.FindStringSubmatch(strings.TrimSpace(line)); m != nil {
			names[m[1]] = true
		}
	}
	return names
}

// enumValueName returns the value declared on a comment-stripped enum body
// line: its first token, so attributes such as @map("x") are dropped. Block
// attributes (@@map, ...) are not values and yield "".
func enumValueName(line string) string {
	if strings.HasPrefix(line, "@@") {
		return ""
	}
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return ""
	}
	return fields[0]
}

// extractAttributeArgs returns the text between the parentheses of the first
// `name(` in attributes (name includes the leading @). Nested parentheses and
// double-quoted strings are respected. ok is false when the attribute is
// absent or its parentheses are unbalanced.
func extractAttributeArgs(attributes, name string) (args string, ok bool) {
	idx := strings.Index(attributes, name+"(")
	if idx == -1 {
		return "", false
	}
	start := idx + len(name) + 1
	depth := 1
	inString := false
	for i := start; i < len(attributes); i++ {
		c := attributes[i]
		switch {
		case inString:
			if c == '\\' {
				i++ // skip the escaped character
			} else if c == '"' {
				inString = false
			}
		case c == '"':
			inString = true
		case c == '(':
			depth++
		case c == ')':
			depth--
			if depth == 0 {
				return attributes[start:i], true
			}
		}
	}
	return "", false
}

// mapArgument returns the value of a `map: "..."` argument in args, or "".
func mapArgument(args string) string {
	if m := reMapArgument.FindStringSubmatch(args); m != nil {
		return m[1]
	}
	return ""
}

// parseAttributeColumns splits a block attribute's column list on top-level
// commas and strips per-column arguments, so "a, b(sort: Desc)" yields [a b].
func parseAttributeColumns(list string) []string {
	var columns []string
	add := func(part string) {
		if paren := strings.IndexByte(part, '('); paren >= 0 {
			part = part[:paren]
		}
		if part = strings.TrimSpace(part); part != "" {
			columns = append(columns, part)
		}
	}

	depth, start := 0, 0
	for i := 0; i < len(list); i++ {
		switch list[i] {
		case '(':
			depth++
		case ')':
			if depth > 0 {
				depth--
			}
		case ',':
			if depth == 0 {
				add(list[start:i])
				start = i + 1
			}
		}
	}
	add(list[start:])
	return columns
}

// unsupportedSQLType extracts the native type from an Unsupported("...")
// field type. ok is false for any other type name.
func unsupportedSQLType(fieldType string) (sqlType string, ok bool) {
	const prefix, suffix = `Unsupported("`, `")`
	if len(fieldType) < len(prefix)+len(suffix) ||
		!strings.HasPrefix(fieldType, prefix) || !strings.HasSuffix(fieldType, suffix) {
		return "", false
	}
	return fieldType[len(prefix) : len(fieldType)-len(suffix)], true
}

// appendColumnComment appends note to column.Comment, separated by "; ".
func appendColumnComment(column *models.Column, note string) {
	if column.Comment != "" {
		column.Comment += "; " + note
		return
	}
	column.Comment = note
}