552 lines
14 KiB
Go
552 lines
14 KiB
Go
package prisma
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
// Writer implements the writers.Writer interface for Prisma schema format
|
|
type Writer struct {
|
|
options *writers.WriterOptions
|
|
}
|
|
|
|
// NewWriter creates a new Prisma writer with the given options
|
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
|
return &Writer{
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
// WriteDatabase writes a Database model to Prisma schema format
|
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
|
content := w.databaseToPrisma(db)
|
|
|
|
if w.options.OutputPath != "" {
|
|
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
|
}
|
|
|
|
fmt.Print(content)
|
|
return nil
|
|
}
|
|
|
|
// WriteSchema writes a Schema model to Prisma schema format
|
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|
// Create temporary database for schema
|
|
db := models.InitDatabase("database")
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// WriteTable writes a Table model to Prisma schema format
|
|
func (w *Writer) WriteTable(table *models.Table) error {
|
|
// Create temporary schema and database for table
|
|
schema := models.InitSchema(table.Schema)
|
|
schema.Tables = []*models.Table{table}
|
|
|
|
return w.WriteSchema(schema)
|
|
}
|
|
|
|
// databaseToPrisma converts a Database to Prisma schema format string
|
|
func (w *Writer) databaseToPrisma(db *models.Database) string {
|
|
var sb strings.Builder
|
|
|
|
// Write datasource block
|
|
sb.WriteString(w.generateDatasource(db))
|
|
sb.WriteString("\n")
|
|
|
|
// Write generator block
|
|
sb.WriteString(w.generateGenerator())
|
|
sb.WriteString("\n")
|
|
|
|
// Process all schemas (typically just one in Prisma)
|
|
for _, schema := range db.Schemas {
|
|
// Write enums
|
|
if len(schema.Enums) > 0 {
|
|
for _, enum := range schema.Enums {
|
|
sb.WriteString(w.enumToPrisma(enum))
|
|
sb.WriteString("\n")
|
|
}
|
|
}
|
|
|
|
// Identify join tables for implicit M2M
|
|
joinTables := w.identifyJoinTables(schema)
|
|
|
|
// Write models (excluding join tables)
|
|
for _, table := range schema.Tables {
|
|
if joinTables[table.Name] {
|
|
continue // Skip join tables
|
|
}
|
|
sb.WriteString(w.tableToPrisma(table, schema, joinTables))
|
|
sb.WriteString("\n")
|
|
}
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// generateDatasource generates the datasource block
|
|
func (w *Writer) generateDatasource(db *models.Database) string {
|
|
provider := "postgresql"
|
|
|
|
// Map database type to Prisma provider
|
|
switch db.DatabaseType {
|
|
case models.PostgresqlDatabaseType:
|
|
provider = "postgresql"
|
|
case models.MSSQLDatabaseType:
|
|
provider = "sqlserver"
|
|
case models.SqlLiteDatabaseType:
|
|
provider = "sqlite"
|
|
case "mysql":
|
|
provider = "mysql"
|
|
}
|
|
|
|
return fmt.Sprintf(`datasource db {
|
|
provider = "%s"
|
|
url = env("DATABASE_URL")
|
|
}
|
|
`, provider)
|
|
}
|
|
|
|
// generateGenerator generates the generator block
|
|
func (w *Writer) generateGenerator() string {
|
|
return `generator client {
|
|
provider = "prisma-client-js"
|
|
}
|
|
`
|
|
}
|
|
|
|
// enumToPrisma converts an Enum to Prisma enum block
|
|
func (w *Writer) enumToPrisma(enum *models.Enum) string {
|
|
var sb strings.Builder
|
|
|
|
sb.WriteString(fmt.Sprintf("enum %s {\n", enum.Name))
|
|
for _, value := range enum.Values {
|
|
sb.WriteString(fmt.Sprintf(" %s\n", value))
|
|
}
|
|
sb.WriteString("}\n")
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// identifyJoinTables identifies tables that are join tables for M2M relations
|
|
func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
|
|
joinTables := make(map[string]bool)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Check if this is a join table:
|
|
// 1. Starts with _ (Prisma convention)
|
|
// 2. Has exactly 2 FK constraints
|
|
// 3. Has composite PK with those 2 columns
|
|
// 4. Has no other columns except the FK columns
|
|
|
|
if !strings.HasPrefix(table.Name, "_") {
|
|
continue
|
|
}
|
|
|
|
fks := table.GetForeignKeys()
|
|
if len(fks) != 2 {
|
|
continue
|
|
}
|
|
|
|
// Check if columns are only the FK columns
|
|
if len(table.Columns) != 2 {
|
|
continue
|
|
}
|
|
|
|
// Check if both FK columns are part of PK
|
|
pkCols := 0
|
|
for _, col := range table.Columns {
|
|
if col.IsPrimaryKey {
|
|
pkCols++
|
|
}
|
|
}
|
|
|
|
if pkCols == 2 {
|
|
joinTables[table.Name] = true
|
|
}
|
|
}
|
|
|
|
return joinTables
|
|
}
|
|
|
|
// tableToPrisma converts a Table to Prisma model block
|
|
func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
|
var sb strings.Builder
|
|
|
|
sb.WriteString(fmt.Sprintf("model %s {\n", table.Name))
|
|
|
|
// Collect columns to write
|
|
columns := make([]*models.Column, 0, len(table.Columns))
|
|
for _, col := range table.Columns {
|
|
columns = append(columns, col)
|
|
}
|
|
|
|
// Sort columns for consistent output
|
|
sort.Slice(columns, func(i, j int) bool {
|
|
return columns[i].Name < columns[j].Name
|
|
})
|
|
|
|
// Write scalar fields
|
|
for _, col := range columns {
|
|
// Skip if this column is part of a relation that will be output as array field
|
|
if w.isRelationColumn(col, table) {
|
|
// We'll output this with the relation field
|
|
continue
|
|
}
|
|
|
|
sb.WriteString(w.columnToField(col, table, schema))
|
|
}
|
|
|
|
// Write relation fields
|
|
sb.WriteString(w.generateRelationFields(table, schema, joinTables))
|
|
|
|
// Write block attributes (@@id, @@unique, @@index)
|
|
sb.WriteString(w.generateBlockAttributes(table))
|
|
|
|
sb.WriteString("}\n")
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// columnToField converts a Column to a Prisma field definition
|
|
func (w *Writer) columnToField(col *models.Column, table *models.Table, schema *models.Schema) string {
|
|
var sb strings.Builder
|
|
|
|
// Field name
|
|
sb.WriteString(fmt.Sprintf(" %s", col.Name))
|
|
|
|
// Field type
|
|
prismaType := w.sqlTypeToPrisma(col.Type, schema)
|
|
sb.WriteString(fmt.Sprintf(" %s", prismaType))
|
|
|
|
// Optional modifier
|
|
if !col.NotNull && !col.IsPrimaryKey {
|
|
sb.WriteString("?")
|
|
}
|
|
|
|
// Field attributes
|
|
attributes := w.generateFieldAttributes(col, table)
|
|
if attributes != "" {
|
|
sb.WriteString(" ")
|
|
sb.WriteString(attributes)
|
|
}
|
|
|
|
sb.WriteString("\n")
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// sqlTypeToPrisma converts SQL types to Prisma types
|
|
func (w *Writer) sqlTypeToPrisma(sqlType string, schema *models.Schema) string {
|
|
// Check if it's an enum
|
|
for _, enum := range schema.Enums {
|
|
if strings.EqualFold(sqlType, enum.Name) {
|
|
return enum.Name
|
|
}
|
|
}
|
|
|
|
// Standard type mapping
|
|
typeMap := map[string]string{
|
|
"text": "String",
|
|
"varchar": "String",
|
|
"character varying": "String",
|
|
"char": "String",
|
|
"boolean": "Boolean",
|
|
"bool": "Boolean",
|
|
"integer": "Int",
|
|
"int": "Int",
|
|
"int4": "Int",
|
|
"bigint": "BigInt",
|
|
"int8": "BigInt",
|
|
"double precision": "Float",
|
|
"float": "Float",
|
|
"float8": "Float",
|
|
"decimal": "Decimal",
|
|
"numeric": "Decimal",
|
|
"timestamp": "DateTime",
|
|
"timestamptz": "DateTime",
|
|
"date": "DateTime",
|
|
"jsonb": "Json",
|
|
"json": "Json",
|
|
"bytea": "Bytes",
|
|
}
|
|
|
|
for sqlPattern, prismaType := range typeMap {
|
|
if strings.Contains(strings.ToLower(sqlType), sqlPattern) {
|
|
return prismaType
|
|
}
|
|
}
|
|
|
|
// Default to String for unknown types
|
|
return "String"
|
|
}
|
|
|
|
// generateFieldAttributes generates field attributes like @id, @unique, @default
|
|
func (w *Writer) generateFieldAttributes(col *models.Column, table *models.Table) string {
|
|
attrs := make([]string, 0)
|
|
|
|
// @id
|
|
if col.IsPrimaryKey {
|
|
// Check if this is part of a composite key
|
|
pkCount := 0
|
|
for _, c := range table.Columns {
|
|
if c.IsPrimaryKey {
|
|
pkCount++
|
|
}
|
|
}
|
|
if pkCount == 1 {
|
|
attrs = append(attrs, "@id")
|
|
}
|
|
}
|
|
|
|
// @unique
|
|
if w.hasUniqueConstraint(col.Name, table) {
|
|
attrs = append(attrs, "@unique")
|
|
}
|
|
|
|
// @default
|
|
if col.AutoIncrement {
|
|
attrs = append(attrs, "@default(autoincrement())")
|
|
} else if col.Default != nil {
|
|
defaultAttr := w.formatDefaultValue(col.Default)
|
|
if defaultAttr != "" {
|
|
attrs = append(attrs, fmt.Sprintf("@default(%s)", defaultAttr))
|
|
}
|
|
}
|
|
|
|
// @updatedAt (check comment)
|
|
if strings.Contains(col.Comment, "@updatedAt") {
|
|
attrs = append(attrs, "@updatedAt")
|
|
}
|
|
|
|
return strings.Join(attrs, " ")
|
|
}
|
|
|
|
// formatDefaultValue formats a default value for Prisma
|
|
func (w *Writer) formatDefaultValue(defaultValue any) string {
|
|
switch v := defaultValue.(type) {
|
|
case string:
|
|
if v == "now()" {
|
|
return "now()"
|
|
} else if v == "gen_random_uuid()" {
|
|
return "uuid()"
|
|
} else if strings.Contains(strings.ToLower(v), "uuid") {
|
|
return "uuid()"
|
|
} else {
|
|
// String literal
|
|
return fmt.Sprintf(`"%s"`, v)
|
|
}
|
|
case bool:
|
|
if v {
|
|
return "true"
|
|
}
|
|
return "false"
|
|
case int, int64, int32:
|
|
return fmt.Sprintf("%v", v)
|
|
default:
|
|
return fmt.Sprintf("%v", v)
|
|
}
|
|
}
|
|
|
|
// hasUniqueConstraint checks if a column has a unique constraint
|
|
func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool {
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.UniqueConstraint &&
|
|
len(constraint.Columns) == 1 &&
|
|
constraint.Columns[0] == colName {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isRelationColumn checks if a column is a FK column
|
|
func (w *Writer) isRelationColumn(col *models.Column, table *models.Table) bool {
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.ForeignKeyConstraint {
|
|
for _, fkCol := range constraint.Columns {
|
|
if fkCol == col.Name {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// generateRelationFields generates relation fields and their FK columns
|
|
func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
|
var sb strings.Builder
|
|
|
|
// Get all FK constraints
|
|
fks := table.GetForeignKeys()
|
|
|
|
for _, fk := range fks {
|
|
// Generate the FK scalar field
|
|
for _, fkCol := range fk.Columns {
|
|
if col, exists := table.Columns[fkCol]; exists {
|
|
sb.WriteString(w.columnToField(col, table, schema))
|
|
}
|
|
}
|
|
|
|
// Generate the relation field
|
|
relationType := fk.ReferencedTable
|
|
isOptional := false
|
|
|
|
// Check if FK column is nullable
|
|
for _, fkCol := range fk.Columns {
|
|
if col, exists := table.Columns[fkCol]; exists {
|
|
if !col.NotNull {
|
|
isOptional = true
|
|
}
|
|
}
|
|
}
|
|
|
|
relationName := relationType
|
|
if strings.HasSuffix(strings.ToLower(relationName), "s") {
|
|
relationName = relationName[:len(relationName)-1]
|
|
}
|
|
|
|
sb.WriteString(fmt.Sprintf(" %s %s", strings.ToLower(relationName), relationType))
|
|
|
|
if isOptional {
|
|
sb.WriteString("?")
|
|
}
|
|
|
|
// @relation attribute
|
|
relationAttr := w.generateRelationAttribute(fk)
|
|
if relationAttr != "" {
|
|
sb.WriteString(" ")
|
|
sb.WriteString(relationAttr)
|
|
}
|
|
|
|
sb.WriteString("\n")
|
|
}
|
|
|
|
// Generate inverse relations (arrays) for tables that reference this one
|
|
sb.WriteString(w.generateInverseRelations(table, schema, joinTables))
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// generateRelationAttribute generates the @relation(...) attribute
|
|
func (w *Writer) generateRelationAttribute(fk *models.Constraint) string {
|
|
parts := make([]string, 0)
|
|
|
|
// fields
|
|
fieldsStr := strings.Join(fk.Columns, ", ")
|
|
parts = append(parts, fmt.Sprintf("fields: [%s]", fieldsStr))
|
|
|
|
// references
|
|
referencesStr := strings.Join(fk.ReferencedColumns, ", ")
|
|
parts = append(parts, fmt.Sprintf("references: [%s]", referencesStr))
|
|
|
|
// onDelete
|
|
if fk.OnDelete != "" {
|
|
parts = append(parts, fmt.Sprintf("onDelete: %s", fk.OnDelete))
|
|
}
|
|
|
|
// onUpdate
|
|
if fk.OnUpdate != "" {
|
|
parts = append(parts, fmt.Sprintf("onUpdate: %s", fk.OnUpdate))
|
|
}
|
|
|
|
return fmt.Sprintf("@relation(%s)", strings.Join(parts, ", "))
|
|
}
|
|
|
|
// generateInverseRelations generates array fields for reverse relationships
|
|
func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
|
var sb strings.Builder
|
|
|
|
// Find all tables that have FKs pointing to this table
|
|
for _, otherTable := range schema.Tables {
|
|
if otherTable.Name == table.Name {
|
|
continue
|
|
}
|
|
|
|
// Check if this is a join table
|
|
if joinTables[otherTable.Name] {
|
|
// Handle implicit M2M
|
|
if w.isJoinTableFor(otherTable, table.Name) {
|
|
// Find the other side of the M2M
|
|
for _, fk := range otherTable.GetForeignKeys() {
|
|
if fk.ReferencedTable != table.Name {
|
|
// This is the other side
|
|
otherSide := fk.ReferencedTable
|
|
sb.WriteString(fmt.Sprintf(" %ss %s[]\n",
|
|
strings.ToLower(otherSide), otherSide))
|
|
break
|
|
}
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Regular one-to-many inverse relation
|
|
for _, fk := range otherTable.GetForeignKeys() {
|
|
if fk.ReferencedTable == table.Name {
|
|
// This table is referenced by otherTable
|
|
pluralName := otherTable.Name
|
|
if !strings.HasSuffix(pluralName, "s") {
|
|
pluralName += "s"
|
|
}
|
|
|
|
sb.WriteString(fmt.Sprintf(" %s %s[]\n",
|
|
strings.ToLower(pluralName), otherTable.Name))
|
|
}
|
|
}
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// isJoinTableFor checks if a table is a join table involving the specified model
|
|
func (w *Writer) isJoinTableFor(joinTable *models.Table, modelName string) bool {
|
|
for _, fk := range joinTable.GetForeignKeys() {
|
|
if fk.ReferencedTable == modelName {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// generateBlockAttributes generates block-level attributes like @@id, @@unique, @@index
|
|
func (w *Writer) generateBlockAttributes(table *models.Table) string {
|
|
var sb strings.Builder
|
|
|
|
// @@id for composite primary key
|
|
pkCols := make([]string, 0)
|
|
for _, col := range table.Columns {
|
|
if col.IsPrimaryKey {
|
|
pkCols = append(pkCols, col.Name)
|
|
}
|
|
}
|
|
|
|
if len(pkCols) > 1 {
|
|
sort.Strings(pkCols)
|
|
sb.WriteString(fmt.Sprintf(" @@id([%s])\n", strings.Join(pkCols, ", ")))
|
|
}
|
|
|
|
// @@unique for multi-column unique constraints
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
|
|
sb.WriteString(fmt.Sprintf(" @@unique([%s])\n", strings.Join(constraint.Columns, ", ")))
|
|
}
|
|
}
|
|
|
|
// @@index for indexes
|
|
for _, index := range table.Indexes {
|
|
if !index.Unique { // Unique indexes are handled by @@unique
|
|
sb.WriteString(fmt.Sprintf(" @@index([%s])\n", strings.Join(index.Columns, ", ")))
|
|
}
|
|
}
|
|
|
|
return sb.String()
|
|
}
|