relspecgo/pkg/readers/prisma/reader.go
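
// Package prisma implements the readers.Reader interface for Prisma schema
// files, converting datasource, model, and enum blocks into relspec database
// models.
//
// A minimal usage sketch (illustrative only; the schema path is a placeholder):
//
//  r := prisma.NewReader(&readers.ReaderOptions{FilePath: "schema.prisma"})
//  db, err := r.ReadDatabase()
//  if err != nil {
//    log.Fatal(err)
//  }
//  fmt.Println(db.Name, len(db.Schemas))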
package prisma
import (
"bufio"
"fmt"
"os"
"regexp"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// Reader implements the readers.Reader interface for Prisma schema format
type Reader struct {
options *readers.ReaderOptions
}
// NewReader creates a new Prisma reader with the given options
func NewReader(options *readers.ReaderOptions) *Reader {
return &Reader{
options: options,
}
}
// ReadDatabase reads and parses Prisma schema input, returning a Database model
func (r *Reader) ReadDatabase() (*models.Database, error) {
if r.options.FilePath == "" {
return nil, fmt.Errorf("file path is required for Prisma reader")
}
content, err := os.ReadFile(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}
return r.parsePrisma(string(content))
}
// ReadSchema reads and parses Prisma schema input, returning a Schema model
func (r *Reader) ReadSchema() (*models.Schema, error) {
db, err := r.ReadDatabase()
if err != nil {
return nil, err
}
if len(db.Schemas) == 0 {
return nil, fmt.Errorf("no schemas found in Prisma schema")
}
// Return the first schema
return db.Schemas[0], nil
}
// ReadTable reads and parses Prisma schema input, returning a Table model
func (r *Reader) ReadTable() (*models.Table, error) {
schema, err := r.ReadSchema()
if err != nil {
return nil, err
}
if len(schema.Tables) == 0 {
return nil, fmt.Errorf("no tables found in Prisma schema")
}
// Return the first table
return schema.Tables[0], nil
}
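
// Illustrative input for parsePrisma (a hypothetical schema; block, model, and
// field names are examples only): the datasource block selects the database
// type, generator blocks are skipped, and model/enum blocks become tables and
// enums in the default "public" schema.
//
//  datasource db {
//    provider = "postgresql"
//    url      = env("DATABASE_URL")
//  }
//
//  enum Role {
//    USER
//    ADMIN
//  }
//
//  model User {
//    id   Int  @id @default(autoincrement())
//    role Role @default(USER)
//  }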
// parsePrisma parses Prisma schema content and returns a Database model
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
db := models.InitDatabase("database")
if r.options.Metadata != nil {
if name, ok := r.options.Metadata["name"].(string); ok {
db.Name = name
}
}
// Default schema: most Prisma schemas do not declare an explicit database schema
schema := models.InitSchema("public")
schema.Enums = make([]*models.Enum, 0)
scanner := bufio.NewScanner(strings.NewReader(content))
// State tracking
var currentBlock string // "datasource", "generator", "model", "enum"
var currentTable *models.Table
var currentEnum *models.Enum
var blockContent []string
// Regex patterns
datasourceRegex := regexp.MustCompile(`^datasource\s+\w+\s*{`)
generatorRegex := regexp.MustCompile(`^generator\s+\w+\s*{`)
modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`)
enumRegex := regexp.MustCompile(`^enum\s+(\w+)\s*{`)
for scanner.Scan() {
line := scanner.Text()
trimmed := strings.TrimSpace(line)
// Skip empty lines and comments
if trimmed == "" || strings.HasPrefix(trimmed, "//") {
continue
}
// Check for block start
if matches := datasourceRegex.FindStringSubmatch(trimmed); matches != nil {
currentBlock = "datasource"
blockContent = []string{}
continue
}
if matches := generatorRegex.FindStringSubmatch(trimmed); matches != nil {
currentBlock = "generator"
blockContent = []string{}
continue
}
if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil {
currentBlock = "model"
tableName := matches[1]
currentTable = models.InitTable(tableName, "public")
blockContent = []string{}
continue
}
if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil {
currentBlock = "enum"
enumName := matches[1]
currentEnum = &models.Enum{
Name: enumName,
Schema: "public",
Values: make([]string, 0),
}
blockContent = []string{}
continue
}
// Check for block end
if trimmed == "}" {
switch currentBlock {
case "datasource":
r.parseDatasource(blockContent, db)
case "generator":
// We don't need to do anything with generator blocks
case "model":
if currentTable != nil {
r.parseModelFields(blockContent, currentTable)
schema.Tables = append(schema.Tables, currentTable)
currentTable = nil
}
case "enum":
if currentEnum != nil {
schema.Enums = append(schema.Enums, currentEnum)
currentEnum = nil
}
}
currentBlock = ""
blockContent = []string{}
continue
}
// Accumulate block content
if currentBlock != "" {
if currentBlock == "enum" && currentEnum != nil {
// For enums, record only the value name; inline attributes (e.g. @map)
// and block attributes (e.g. @@map) are not enum values
if trimmed != "" && !strings.HasPrefix(trimmed, "@@") {
currentEnum.Values = append(currentEnum.Values, strings.Fields(trimmed)[0])
}
} else {
blockContent = append(blockContent, line)
}
}
}
// Second pass: resolve relationships
r.resolveRelationships(schema)
db.Schemas = append(db.Schemas, schema)
return db, nil
}
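
// Illustrative datasource block (hypothetical): the provider value drives the
// DatabaseType mapping below, e.g. provider = "sqlserver" selects
// models.MSSQLDatabaseType.
//
//  datasource db {
//    provider = "sqlserver"
//    url      = env("DATABASE_URL")
//  }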
// parseDatasource extracts database type from datasource block
func (r *Reader) parseDatasource(lines []string, db *models.Database) {
providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)
for _, line := range lines {
if matches := providerRegex.FindStringSubmatch(line); matches != nil {
provider := matches[1]
switch provider {
case "postgresql", "postgres":
db.DatabaseType = models.PostgresqlDatabaseType
case "mysql":
db.DatabaseType = "mysql"
case "sqlite":
db.DatabaseType = models.SqlLiteDatabaseType
case "sqlserver":
db.DatabaseType = models.MSSQLDatabaseType
default:
db.DatabaseType = models.PostgresqlDatabaseType
}
break
}
}
}
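
// Illustrative model body for parseModelFields (model, field, and column names
// are hypothetical): plain field lines become columns, while @@-prefixed lines
// become table-level constraints and indexes.
//
//  model Account {
//    id       Int    @id @default(autoincrement())
//    email    String @unique
//    tenantId Int
//    @@unique([email, tenantId])
//    @@index([tenantId])
//  }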
// parseModelFields parses model field definitions
func (r *Reader) parseModelFields(lines []string, table *models.Table) {
fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`)
blockAttrRegex := regexp.MustCompile(`^@@(\w+)\((.*?)\)`)
for _, line := range lines {
trimmed := strings.TrimSpace(line)
// Skip empty lines and comments
if trimmed == "" || strings.HasPrefix(trimmed, "//") {
continue
}
// Check for block attributes (@@id, @@unique, @@index)
if matches := blockAttrRegex.FindStringSubmatch(trimmed); matches != nil {
attrName := matches[1]
attrContent := matches[2]
r.parseBlockAttribute(attrName, attrContent, table)
continue
}
// Parse field definition
if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil {
fieldName := matches[1]
fieldType := matches[2]
modifier := matches[3] // ? or []
attributes := matches[4] // @... part
column := r.parseField(fieldName, fieldType, modifier, attributes, table)
if column != nil {
table.Columns[column.Name] = column
}
}
}
}
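
// Illustrative field definitions for parseField (names are hypothetical):
//
//  title    String   // scalar column, NOT NULL
//  summary  String?  // optional scalar column, nullable
//  tags     Tag[]    // array relation field, no column emitted
//  author   User     @relation(fields: [authorId], references: [id])
//  authorId Int      // the actual FK column
//
// Only the scalar lines produce columns here; relation fields are skipped and
// resolved later in the relationship pass.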
// parseField parses a single field definition
func (r *Reader) parseField(name, fieldType, modifier, attributes string, table *models.Table) *models.Column {
// Array fields are relation fields, not columns; they are resolved during
// relationship resolution
if modifier == "[]" {
return nil
}
// Fields carrying an @relation attribute are relation fields as well; the
// actual FK column is declared as its own scalar field and listed in the
// fields: [...] argument
if strings.Contains(attributes, "@relation") {
return nil
}
// Any other non-primitive, non-enum type is the inverse side of a relation
// (e.g. "user User"), which also does not map to a column
if !r.isPrimitiveType(fieldType) && !r.isEnumType(fieldType, table) {
return nil
}
column := models.InitColumn(name, table.Name, table.Schema)
// Map Prisma type to SQL type
column.Type = r.prismaTypeToSQL(fieldType)
// Handle modifiers
if modifier == "?" {
column.NotNull = false
} else {
// Default: required fields are NOT NULL
column.NotNull = true
}
// Parse field attributes
if attributes != "" {
r.parseFieldAttributes(attributes, column, table)
}
return column
}
// prismaTypeToSQL converts Prisma types to SQL types
func (r *Reader) prismaTypeToSQL(prismaType string) string {
typeMap := map[string]string{
"String": "text",
"Boolean": "boolean",
"Int": "integer",
"BigInt": "bigint",
"Float": "double precision",
"Decimal": "decimal",
"DateTime": "timestamp",
"Json": "jsonb",
"Bytes": "bytea",
}
if sqlType, ok := typeMap[prismaType]; ok {
return sqlType
}
// If not a built-in type, it might be an enum or model reference
// For enums, we'll use the enum name directly
return prismaType
}
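
// Illustrative attribute strings handled by parseFieldAttributes (all
// hypothetical):
//
//  @id @default(autoincrement())  -> primary key, auto-increment
//  @unique                        -> single-column unique constraint
//  @default("pending")            -> string default value
//  @updatedAt                     -> recorded in the column comment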
// parseFieldAttributes parses field attributes like @id, @unique, @default
func (r *Reader) parseFieldAttributes(attributes string, column *models.Column, table *models.Table) {
// @id attribute
if strings.Contains(attributes, "@id") {
column.IsPrimaryKey = true
column.NotNull = true
}
// @unique attribute
if regexp.MustCompile(`@unique\b`).MatchString(attributes) {
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("uq_%s", column.Name),
models.UniqueConstraint,
)
uniqueConstraint.Schema = table.Schema
uniqueConstraint.Table = table.Name
uniqueConstraint.Columns = []string{column.Name}
table.Constraints[uniqueConstraint.Name] = uniqueConstraint
}
// @default attribute - extract value with balanced parentheses
if strings.Contains(attributes, "@default(") {
defaultValue := r.extractDefaultValue(attributes)
if defaultValue != "" {
r.parseDefaultValue(defaultValue, column)
}
}
// @updatedAt attribute - store in comment for now
if strings.Contains(attributes, "@updatedAt") {
if column.Comment != "" {
column.Comment += "; @updatedAt"
} else {
column.Comment = "@updatedAt"
}
}
// @relation attributes are handled later during relationship resolution;
// nothing needs to be recorded on the column here
}
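
// Illustrative input for extractDefaultValue (hypothetical): for
// @default(dbgenerated("uuid_generate_v4()")) the balanced-parentheses scan
// returns dbgenerated("uuid_generate_v4()"), which a naive regex stopping at
// the first ')' would truncate.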
// extractDefaultValue extracts the default value from @default(...) handling nested parentheses
func (r *Reader) extractDefaultValue(attributes string) string {
idx := strings.Index(attributes, "@default(")
if idx == -1 {
return ""
}
start := idx + len("@default(")
depth := 1
i := start
for i < len(attributes) && depth > 0 {
switch attributes[i] {
case '(':
depth++
case ')':
depth--
}
i++
}
if depth == 0 {
return attributes[start : i-1]
}
return ""
}
// parseDefaultValue parses Prisma default value expressions
func (r *Reader) parseDefaultValue(defaultExpr string, column *models.Column) {
defaultExpr = strings.TrimSpace(defaultExpr)
switch defaultExpr {
case "autoincrement()":
column.AutoIncrement = true
case "now()":
column.Default = "now()"
case "uuid()":
column.Default = "gen_random_uuid()"
case "cuid()":
// CUID is Prisma-specific, store in comment
if column.Comment != "" {
column.Comment += "; default(cuid())"
} else {
column.Comment = "default(cuid())"
}
case "true":
column.Default = true
case "false":
column.Default = false
default:
// Check if it's a string literal
if strings.HasPrefix(defaultExpr, "\"") && strings.HasSuffix(defaultExpr, "\"") {
column.Default = defaultExpr[1 : len(defaultExpr)-1]
} else if strings.HasPrefix(defaultExpr, "'") && strings.HasSuffix(defaultExpr, "'") {
column.Default = defaultExpr[1 : len(defaultExpr)-1]
} else {
// Try to parse as number or enum value
column.Default = defaultExpr
}
}
}
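
// Illustrative block attributes for parseBlockAttribute (column names are
// hypothetical):
//
//  @@id([orderId, productId])   -> composite primary key pk_<table>
//  @@unique([email, tenantId])  -> unique constraint uq_<table>_email_tenantId
//  @@index([createdAt])         -> index idx_<table>_createdAt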
// parseBlockAttribute parses block-level attributes like @@id, @@unique, @@index
func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Table) {
// Extract column list from brackets [col1, col2]
colListRegex := regexp.MustCompile(`\[(.*?)\]`)
matches := colListRegex.FindStringSubmatch(content)
if matches == nil {
return
}
columnList := strings.Split(matches[1], ",")
columns := make([]string, 0)
for _, col := range columnList {
columns = append(columns, strings.TrimSpace(col))
}
switch attrName {
case "id":
// Composite primary key
for _, colName := range columns {
if col, exists := table.Columns[colName]; exists {
col.IsPrimaryKey = true
col.NotNull = true
}
}
// Also create a PK constraint
pkConstraint := models.InitConstraint(
fmt.Sprintf("pk_%s", table.Name),
models.PrimaryKeyConstraint,
)
pkConstraint.Schema = table.Schema
pkConstraint.Table = table.Name
pkConstraint.Columns = columns
table.Constraints[pkConstraint.Name] = pkConstraint
case "unique":
// Multi-column unique constraint
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("uq_%s_%s", table.Name, strings.Join(columns, "_")),
models.UniqueConstraint,
)
uniqueConstraint.Schema = table.Schema
uniqueConstraint.Table = table.Name
uniqueConstraint.Columns = columns
table.Constraints[uniqueConstraint.Name] = uniqueConstraint
case "index":
// Index
index := models.InitIndex(
fmt.Sprintf("idx_%s_%s", table.Name, strings.Join(columns, "_")),
table.Name,
table.Schema,
)
index.Columns = columns
table.Indexes[index.Name] = index
}
}
// relationField stores information about a relation field for second-pass processing
type relationField struct {
tableName string
fieldName string
relatedModel string
isArray bool
relationAttr string
}
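
// Illustrative relation pair feeding resolveRelationships (model names are
// hypothetical): the owning side carries @relation and yields an FK constraint
// on Post; the Post[] field on User is only the inverse side.
//
//  model Post {
//    id       Int  @id @default(autoincrement())
//    authorId Int
//    author   User @relation(fields: [authorId], references: [id])
//  }
//
//  model User {
//    id    Int    @id @default(autoincrement())
//    posts Post[]
//  }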
// resolveRelationships performs a second pass to resolve @relation attributes
func (r *Reader) resolveRelationships(schema *models.Schema) {
// Build a map of table names for quick lookup
tableMap := make(map[string]*models.Table)
for _, table := range schema.Tables {
tableMap[table.Name] = table
}
// Relation fields are not retained from the first pass, so re-read the
// schema file and extract them now
if r.options.FilePath == "" {
return
}
content, err := os.ReadFile(r.options.FilePath)
if err != nil {
return
}
relations := r.extractRelationFields(string(content))
// Process explicit @relation attributes to create FK constraints
for _, rel := range relations {
if rel.relationAttr != "" {
r.createConstraintFromRelation(rel, tableMap, schema)
}
}
// Detect implicit many-to-many relationships
r.detectImplicitManyToMany(relations, tableMap, schema)
}
// extractRelationFields extracts relation field information from the schema
func (r *Reader) extractRelationFields(content string) []relationField {
relations := make([]relationField, 0)
scanner := bufio.NewScanner(strings.NewReader(content))
modelRegex := regexp.MustCompile(`^model\s+(\w+)\s*{`)
fieldRegex := regexp.MustCompile(`^(\w+)\s+(\w+)(\?|\[\])?\s*(@.+)?`)
var currentModel string
inModel := false
for scanner.Scan() {
line := scanner.Text()
trimmed := strings.TrimSpace(line)
if trimmed == "" || strings.HasPrefix(trimmed, "//") {
continue
}
if matches := modelRegex.FindStringSubmatch(trimmed); matches != nil {
currentModel = matches[1]
inModel = true
continue
}
if trimmed == "}" {
inModel = false
currentModel = ""
continue
}
if inModel && currentModel != "" {
if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil {
fieldName := matches[1]
fieldType := matches[2]
modifier := matches[3]
attributes := matches[4]
// Check if this is a relation field (references another model or is an array)
isPotentialRelation := modifier == "[]" || !r.isPrimitiveType(fieldType)
if isPotentialRelation {
rel := relationField{
tableName: currentModel,
fieldName: fieldName,
relatedModel: fieldType,
isArray: modifier == "[]",
relationAttr: attributes,
}
relations = append(relations, rel)
}
}
}
}
return relations
}
// isPrimitiveType checks if a type is a Prisma primitive type
func (r *Reader) isPrimitiveType(typeName string) bool {
primitives := []string{"String", "Boolean", "Int", "BigInt", "Float", "Decimal", "DateTime", "Json", "Bytes"}
for _, p := range primitives {
if typeName == p {
return true
}
}
return false
}
// isEnumType reports whether a type name looks like an enum.
// We cannot check against schema.Enums at parse time because enums may be
// defined after the model, so we fall back to a naming heuristic: an
// uppercase, single-word name is treated as an enum (this also matches model names)
func (r *Reader) isEnumType(typeName string, table *models.Table) bool {
// Simple heuristic: enum types start with uppercase letter
// and are not known model names (though we can't check that yet)
if len(typeName) > 0 && typeName[0] >= 'A' && typeName[0] <= 'Z' {
// Additional check: primitive types are already handled above
// So if it's uppercase and not primitive, it's likely an enum or model
// We'll assume it's an enum if it's a single word
return !strings.Contains(typeName, "_")
}
return false
}
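
// Illustrative @relation attribute for createConstraintFromRelation
// (hypothetical): for a field on Post declared as
//
//  author User @relation(fields: [authorId], references: [id], onDelete: Cascade)
//
// the resulting constraint is fk_Post_authorId with Columns [authorId],
// ReferencedTable User, ReferencedColumns [id], and OnDelete Cascade.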
// createConstraintFromRelation creates a FK constraint from a @relation attribute
func (r *Reader) createConstraintFromRelation(rel relationField, tableMap map[string]*models.Table, schema *models.Schema) {
// Skip array fields (they are the inverse side of the relation)
if rel.isArray {
return
}
if rel.relationAttr == "" {
return
}
// Parse @relation attribute
relationRegex := regexp.MustCompile(`@relation\((.*?)\)`)
matches := relationRegex.FindStringSubmatch(rel.relationAttr)
if matches == nil {
return
}
relationContent := matches[1]
// Extract fields and references
fieldsRegex := regexp.MustCompile(`fields:\s*\[(.*?)\]`)
referencesRegex := regexp.MustCompile(`references:\s*\[(.*?)\]`)
nameRegex := regexp.MustCompile(`name:\s*"([^"]+)"`)
onDeleteRegex := regexp.MustCompile(`onDelete:\s*(\w+)`)
onUpdateRegex := regexp.MustCompile(`onUpdate:\s*(\w+)`)
fieldsMatch := fieldsRegex.FindStringSubmatch(relationContent)
referencesMatch := referencesRegex.FindStringSubmatch(relationContent)
if fieldsMatch == nil || referencesMatch == nil {
return
}
// Parse field and reference column lists
fieldCols := r.parseColumnList(fieldsMatch[1])
refCols := r.parseColumnList(referencesMatch[1])
if len(fieldCols) == 0 || len(refCols) == 0 {
return
}
// Create FK constraint
constraintName := fmt.Sprintf("fk_%s_%s", rel.tableName, fieldCols[0])
// Check for custom name
if nameMatch := nameRegex.FindStringSubmatch(relationContent); nameMatch != nil {
constraintName = nameMatch[1]
}
constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
constraint.Schema = "public"
constraint.Table = rel.tableName
constraint.Columns = fieldCols
constraint.ReferencedSchema = "public"
constraint.ReferencedTable = rel.relatedModel
constraint.ReferencedColumns = refCols
// Parse referential actions
if onDeleteMatch := onDeleteRegex.FindStringSubmatch(relationContent); onDeleteMatch != nil {
constraint.OnDelete = onDeleteMatch[1]
}
if onUpdateMatch := onUpdateRegex.FindStringSubmatch(relationContent); onUpdateMatch != nil {
constraint.OnUpdate = onUpdateMatch[1]
}
// Add constraint to table
if table, exists := tableMap[rel.tableName]; exists {
table.Constraints[constraint.Name] = constraint
}
}
// parseColumnList parses a comma-separated list of column names
func (r *Reader) parseColumnList(list string) []string {
parts := strings.Split(list, ",")
result := make([]string, 0)
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
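
// Illustrative implicit many-to-many for detectImplicitManyToMany
// (hypothetical models): list fields on both sides without @relation
// arguments trigger creation of a join table named _PostToTag.
//
//  model Post {
//    id   Int   @id @default(autoincrement())
//    tags Tag[]
//  }
//
//  model Tag {
//    id    Int    @id @default(autoincrement())
//    posts Post[]
//  }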
// detectImplicitManyToMany detects implicit M2M relationships and creates join tables
func (r *Reader) detectImplicitManyToMany(relations []relationField, tableMap map[string]*models.Table, schema *models.Schema) {
// Group relations by model pairs
type modelPair struct {
model1 string
model2 string
}
pairMap := make(map[modelPair][]relationField)
for _, rel := range relations {
if !rel.isArray || rel.relationAttr != "" {
// Skip non-array fields and explicit relations
continue
}
// Create a normalized pair (alphabetically sorted to avoid duplicates)
pair := modelPair{}
if rel.tableName < rel.relatedModel {
pair.model1 = rel.tableName
pair.model2 = rel.relatedModel
} else {
pair.model1 = rel.relatedModel
pair.model2 = rel.tableName
}
pairMap[pair] = append(pairMap[pair], rel)
}
// Check for pairs with arrays on both sides (implicit M2M)
for pair, rels := range pairMap {
if len(rels) >= 2 {
// This is an implicit many-to-many relationship
r.createImplicitJoinTable(pair.model1, pair.model2, tableMap, schema)
}
}
}
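
// For the hypothetical Post/Tag pair above, the generated join table
// _PostToTag gets PostId and TagId columns typed after each model's primary
// key, a composite primary key over both, and cascading FKs back to Post and Tag.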
// createImplicitJoinTable creates a virtual join table for implicit M2M relations
func (r *Reader) createImplicitJoinTable(model1, model2 string, tableMap map[string]*models.Table, schema *models.Schema) {
// Prisma naming convention: _Model1ToModel2 (alphabetically sorted)
joinTableName := fmt.Sprintf("_%sTo%s", model1, model2)
// Check if join table already exists
if _, exists := tableMap[joinTableName]; exists {
return
}
// Create join table
joinTable := models.InitTable(joinTableName, "public")
// Get primary keys from both tables
pk1 := r.getPrimaryKeyColumn(tableMap[model1])
pk2 := r.getPrimaryKeyColumn(tableMap[model2])
if pk1 == nil || pk2 == nil {
return // Can't create join table without PKs
}
// Create FK columns in join table
fkCol1Name := fmt.Sprintf("%sId", model1)
fkCol1 := models.InitColumn(fkCol1Name, joinTableName, "public")
fkCol1.Type = pk1.Type
fkCol1.NotNull = true
joinTable.Columns[fkCol1Name] = fkCol1
fkCol2Name := fmt.Sprintf("%sId", model2)
fkCol2 := models.InitColumn(fkCol2Name, joinTableName, "public")
fkCol2.Type = pk2.Type
fkCol2.NotNull = true
joinTable.Columns[fkCol2Name] = fkCol2
// Create composite primary key
pkConstraint := models.InitConstraint(
fmt.Sprintf("pk_%s", joinTableName),
models.PrimaryKeyConstraint,
)
pkConstraint.Schema = "public"
pkConstraint.Table = joinTableName
pkConstraint.Columns = []string{fkCol1Name, fkCol2Name}
joinTable.Constraints[pkConstraint.Name] = pkConstraint
// Mark columns as PK
fkCol1.IsPrimaryKey = true
fkCol2.IsPrimaryKey = true
// Create FK constraints
fk1 := models.InitConstraint(
fmt.Sprintf("fk_%s_%s", joinTableName, model1),
models.ForeignKeyConstraint,
)
fk1.Schema = "public"
fk1.Table = joinTableName
fk1.Columns = []string{fkCol1Name}
fk1.ReferencedSchema = "public"
fk1.ReferencedTable = model1
fk1.ReferencedColumns = []string{pk1.Name}
fk1.OnDelete = "Cascade"
joinTable.Constraints[fk1.Name] = fk1
fk2 := models.InitConstraint(
fmt.Sprintf("fk_%s_%s", joinTableName, model2),
models.ForeignKeyConstraint,
)
fk2.Schema = "public"
fk2.Table = joinTableName
fk2.Columns = []string{fkCol2Name}
fk2.ReferencedSchema = "public"
fk2.ReferencedTable = model2
fk2.ReferencedColumns = []string{pk2.Name}
fk2.OnDelete = "Cascade"
joinTable.Constraints[fk2.Name] = fk2
// Add join table to schema
schema.Tables = append(schema.Tables, joinTable)
tableMap[joinTableName] = joinTable
}
// getPrimaryKeyColumn returns the primary key column of a table
func (r *Reader) getPrimaryKeyColumn(table *models.Table) *models.Column {
if table == nil {
return nil
}
for _, col := range table.Columns {
if col.IsPrimaryKey {
return col
}
}
return nil
}