package prisma

import (
	"fmt"
	"os"
	"sort"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
)

// Writer implements the writers.Writer interface for Prisma schema format.
type Writer struct {
	options *writers.WriterOptions
}

// NewWriter creates a new Prisma writer with the given options.
func NewWriter(options *writers.WriterOptions) *Writer {
	return &Writer{
		options: options,
	}
}

// WriteDatabase writes a Database model to Prisma schema format.
//
// If the writer's options carry a non-empty OutputPath, the schema is written
// to that file with mode 0644. Otherwise it is printed to standard output.
// A nil options value is treated the same as an empty OutputPath.
func (w *Writer) WriteDatabase(db *models.Database) error {
	content := w.databaseToPrisma(db)

	if w.options != nil && w.options.OutputPath != "" {
		if err := os.WriteFile(w.options.OutputPath, []byte(content), 0644); err != nil {
			return fmt.Errorf("writing prisma schema to %q: %w", w.options.OutputPath, err)
		}
		return nil
	}

	fmt.Print(content)
	return nil
}

// WriteSchema writes a Schema model to Prisma schema format by wrapping it in
// a temporary single-schema database.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	db := models.InitDatabase("database")
	db.Schemas = []*models.Schema{schema}

	return w.WriteDatabase(db)
}

// WriteTable writes a Table model to Prisma schema format by wrapping it in a
// temporary schema (named after table.Schema) and database.
func (w *Writer) WriteTable(table *models.Table) error {
	schema := models.InitSchema(table.Schema)
	schema.Tables = []*models.Table{table}

	return w.WriteSchema(schema)
}

// databaseToPrisma converts a Database to a Prisma schema string.
//
// Output order is: datasource block, generator block, and then, for every
// schema, its enums followed by one model per table. Tables detected as
// implicit many-to-many join tables are not emitted as models.
func (w *Writer) databaseToPrisma(db *models.Database) string {
	var sb strings.Builder

	sb.WriteString(w.generateDatasource(db))
	sb.WriteString("\n")

	sb.WriteString(w.generateGenerator())
	sb.WriteString("\n")

	// Prisma has no schema namespaces; every schema's contents are written
	// into the same output, typically there is just one.
	for _, schema := range db.Schemas {
		for _, enum := range schema.Enums {
			sb.WriteString(w.enumToPrisma(enum))
			sb.WriteString("\n")
		}

		joinTables := w.identifyJoinTables(schema)

		for _, table := range schema.Tables {
			if joinTables[table.Name] {
				continue // represented as list fields on both sides instead
			}
			sb.WriteString(w.tableToPrisma(table, schema, joinTables))
			sb.WriteString("\n")
		}
	}

	return sb.String()
}

// generateDatasource generates the datasource block. The provider is derived
// from db.DatabaseType and defaults to "postgresql" for unrecognized types.
func (w *Writer) generateDatasource(db *models.Database) string {
	provider := "postgresql"

	switch db.DatabaseType {
	case models.PostgresqlDatabaseType:
		provider = "postgresql"
	case models.MSSQLDatabaseType:
		provider = "sqlserver"
	case models.SqlLiteDatabaseType:
		provider = "sqlite"
	case "mysql":
		provider = "mysql"
	}

	return fmt.Sprintf(`datasource db {
  provider = "%s"
  url      = env("DATABASE_URL")
}
`, provider)
}

// generateGenerator generates the generator block for prisma-client-js.
func (w *Writer) generateGenerator() string {
	return `generator client {
  provider = "prisma-client-js"
}
`
}

// enumToPrisma converts an Enum to a Prisma enum block, one value per line.
func (w *Writer) enumToPrisma(enum *models.Enum) string {
	var sb strings.Builder

	sb.WriteString(fmt.Sprintf("enum %s {\n", enum.Name))
	for _, value := range enum.Values {
		sb.WriteString(fmt.Sprintf("  %s\n", value))
	}
	sb.WriteString("}\n")

	return sb.String()
}

// identifyJoinTables returns the set of table names in schema that look like
// implicit many-to-many join tables. A table qualifies when:
//  1. its name starts with "_" (Prisma's naming convention for join tables),
//  2. it has exactly two foreign key constraints,
//  3. it has exactly two columns, and
//  4. both of those columns are primary key columns.
func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
	joinTables := make(map[string]bool)

	for _, table := range schema.Tables {
		if !strings.HasPrefix(table.Name, "_") {
			continue
		}

		if len(table.GetForeignKeys()) != 2 {
			continue
		}

		if len(table.Columns) != 2 {
			continue
		}

		// With exactly two columns, both must form the composite PK.
		pkCols := 0
		for _, col := range table.Columns {
			if col.IsPrimaryKey {
				pkCols++
			}
		}

		if pkCols == 2 {
			joinTables[table.Name] = true
		}
	}

	return joinTables
}

// tableToPrisma converts a Table to a Prisma model block.
//
// Non-FK scalar fields are written first in name order. They are followed by
// relation fields (including their FK scalar columns) and inverse relations,
// and finally block attributes (@@id, @@unique, @@index).
func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
	var sb strings.Builder

	sb.WriteString(fmt.Sprintf("model %s {\n", table.Name))

	// table.Columns is iterated as a map, so copy and sort for stable output.
	columns := make([]*models.Column, 0, len(table.Columns))
	for _, col := range table.Columns {
		columns = append(columns, col)
	}
	sort.Slice(columns, func(i, j int) bool {
		return columns[i].Name < columns[j].Name
	})

	for _, col := range columns {
		// FK columns are written next to their relation field instead.
		if w.isRelationColumn(col, table) {
			continue
		}
		sb.WriteString(w.columnToField(col, table, schema))
	}

	sb.WriteString(w.generateRelationFields(table, schema, joinTables))
	sb.WriteString(w.generateBlockAttributes(table))

	sb.WriteString("}\n")

	return sb.String()
}

// columnToField converts a Column to a single Prisma field line:
// name, type, an optional "?" (nullable non-PK columns), then attributes.
func (w *Writer) columnToField(col *models.Column, table *models.Table, schema *models.Schema) string {
	var sb strings.Builder

	sb.WriteString(fmt.Sprintf("  %s", col.Name))
	sb.WriteString(fmt.Sprintf(" %s", w.sqlTypeToPrisma(col.Type, schema)))

	if !col.NotNull && !col.IsPrimaryKey {
		sb.WriteString("?")
	}

	if attributes := w.generateFieldAttributes(col, table); attributes != "" {
		sb.WriteString(" ")
		sb.WriteString(attributes)
	}

	sb.WriteString("\n")

	return sb.String()
}

// prismaTypePatterns maps lowercase SQL type fragments to Prisma scalar types.
//
// Matching is by substring, so the order matters. More specific fragments must
// come before the shorter fragments they contain, for example "bigint" and
// "int8" (BigInt) before "int" (Int). A slice is used instead of a map because
// Go map iteration order is randomized, which would make the result for types
// matching several fragments nondeterministic.
var prismaTypePatterns = []struct {
	pattern, prismaType string
}{
	{"character varying", "String"},
	{"double precision", "Float"},
	{"timestamptz", "DateTime"},
	{"timestamp", "DateTime"},
	{"varchar", "String"},
	{"boolean", "Boolean"},
	{"integer", "Int"},
	{"decimal", "Decimal"},
	{"numeric", "Decimal"},
	{"bigint", "BigInt"},
	{"float8", "Float"},
	{"float", "Float"},
	{"jsonb", "Json"},
	{"bytea", "Bytes"},
	{"int8", "BigInt"},
	{"int4", "Int"},
	{"text", "String"},
	{"char", "String"},
	{"bool", "Boolean"},
	{"date", "DateTime"},
	{"json", "Json"},
	{"int", "Int"},
}

// sqlTypeToPrisma converts an SQL type name to a Prisma type.
//
// A case-insensitive match against an enum declared in schema returns that
// enum's name. Otherwise the first matching entry of prismaTypePatterns wins
// (substring match, case-insensitive). Unknown types default to "String".
func (w *Writer) sqlTypeToPrisma(sqlType string, schema *models.Schema) string {
	for _, enum := range schema.Enums {
		if strings.EqualFold(sqlType, enum.Name) {
			return enum.Name
		}
	}

	lowered := strings.ToLower(sqlType)
	for _, entry := range prismaTypePatterns {
		if strings.Contains(lowered, entry.pattern) {
			return entry.prismaType
		}
	}

	return "String"
}

// generateFieldAttributes builds the space-separated field attributes for col:
//   - @id when col is the table's only primary key column (composite keys
//     are emitted as @@id instead),
//   - @unique when a single-column unique constraint covers col,
//   - @default(autoincrement()) for auto-increment columns, otherwise
//     @default(...) from col.Default when it formats to a non-empty value,
//   - @updatedAt when the column comment contains the marker "@updatedAt".
func (w *Writer) generateFieldAttributes(col *models.Column, table *models.Table) string {
	attrs := make([]string, 0)

	if col.IsPrimaryKey {
		pkCount := 0
		for _, c := range table.Columns {
			if c.IsPrimaryKey {
				pkCount++
			}
		}
		if pkCount == 1 {
			attrs = append(attrs, "@id")
		}
	}

	if w.hasUniqueConstraint(col.Name, table) {
		attrs = append(attrs, "@unique")
	}

	if col.AutoIncrement {
		attrs = append(attrs, "@default(autoincrement())")
	} else if col.Default != nil {
		if defaultAttr := w.formatDefaultValue(col.Default); defaultAttr != "" {
			attrs = append(attrs, fmt.Sprintf("@default(%s)", defaultAttr))
		}
	}

	if strings.Contains(col.Comment, "@updatedAt") {
		attrs = append(attrs, "@updatedAt")
	}

	return strings.Join(attrs, " ")
}
default value for Prisma func (w *Writer) formatDefaultValue(defaultValue any) string { switch v := defaultValue.(type) { case string: if v == "now()" { return "now()" } else if v == "gen_random_uuid()" { return "uuid()" } else if strings.Contains(strings.ToLower(v), "uuid") { return "uuid()" } else { // String literal return fmt.Sprintf(`"%s"`, v) } case bool: if v { return "true" } return "false" case int, int64, int32: return fmt.Sprintf("%v", v) default: return fmt.Sprintf("%v", v) } } // hasUniqueConstraint checks if a column has a unique constraint func (w *Writer) hasUniqueConstraint(colName string, table *models.Table) bool { for _, constraint := range table.Constraints { if constraint.Type == models.UniqueConstraint && len(constraint.Columns) == 1 && constraint.Columns[0] == colName { return true } } return false } // isRelationColumn checks if a column is a FK column func (w *Writer) isRelationColumn(col *models.Column, table *models.Table) bool { for _, constraint := range table.Constraints { if constraint.Type == models.ForeignKeyConstraint { for _, fkCol := range constraint.Columns { if fkCol == col.Name { return true } } } } return false } // generateRelationFields generates relation fields and their FK columns func (w *Writer) generateRelationFields(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { var sb strings.Builder // Get all FK constraints fks := table.GetForeignKeys() for _, fk := range fks { // Generate the FK scalar field for _, fkCol := range fk.Columns { if col, exists := table.Columns[fkCol]; exists { sb.WriteString(w.columnToField(col, table, schema)) } } // Generate the relation field relationType := fk.ReferencedTable isOptional := false // Check if FK column is nullable for _, fkCol := range fk.Columns { if col, exists := table.Columns[fkCol]; exists { if !col.NotNull { isOptional = true } } } relationName := relationType if strings.HasSuffix(strings.ToLower(relationName), "s") { relationName = 
relationName[:len(relationName)-1] } sb.WriteString(fmt.Sprintf(" %s %s", strings.ToLower(relationName), relationType)) if isOptional { sb.WriteString("?") } // @relation attribute relationAttr := w.generateRelationAttribute(fk) if relationAttr != "" { sb.WriteString(" ") sb.WriteString(relationAttr) } sb.WriteString("\n") } // Generate inverse relations (arrays) for tables that reference this one sb.WriteString(w.generateInverseRelations(table, schema, joinTables)) return sb.String() } // generateRelationAttribute generates the @relation(...) attribute func (w *Writer) generateRelationAttribute(fk *models.Constraint) string { parts := make([]string, 0) // fields fieldsStr := strings.Join(fk.Columns, ", ") parts = append(parts, fmt.Sprintf("fields: [%s]", fieldsStr)) // references referencesStr := strings.Join(fk.ReferencedColumns, ", ") parts = append(parts, fmt.Sprintf("references: [%s]", referencesStr)) // onDelete if fk.OnDelete != "" { parts = append(parts, fmt.Sprintf("onDelete: %s", fk.OnDelete)) } // onUpdate if fk.OnUpdate != "" { parts = append(parts, fmt.Sprintf("onUpdate: %s", fk.OnUpdate)) } return fmt.Sprintf("@relation(%s)", strings.Join(parts, ", ")) } // generateInverseRelations generates array fields for reverse relationships func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Schema, joinTables map[string]bool) string { var sb strings.Builder // Find all tables that have FKs pointing to this table for _, otherTable := range schema.Tables { if otherTable.Name == table.Name { continue } // Check if this is a join table if joinTables[otherTable.Name] { // Handle implicit M2M if w.isJoinTableFor(otherTable, table.Name) { // Find the other side of the M2M for _, fk := range otherTable.GetForeignKeys() { if fk.ReferencedTable != table.Name { // This is the other side otherSide := fk.ReferencedTable sb.WriteString(fmt.Sprintf(" %ss %s[]\n", strings.ToLower(otherSide), otherSide)) break } } } continue } // Regular one-to-many 
inverse relation for _, fk := range otherTable.GetForeignKeys() { if fk.ReferencedTable == table.Name { // This table is referenced by otherTable pluralName := otherTable.Name if !strings.HasSuffix(pluralName, "s") { pluralName += "s" } sb.WriteString(fmt.Sprintf(" %s %s[]\n", strings.ToLower(pluralName), otherTable.Name)) } } } return sb.String() } // isJoinTableFor checks if a table is a join table involving the specified model func (w *Writer) isJoinTableFor(joinTable *models.Table, modelName string) bool { for _, fk := range joinTable.GetForeignKeys() { if fk.ReferencedTable == modelName { return true } } return false } // generateBlockAttributes generates block-level attributes like @@id, @@unique, @@index func (w *Writer) generateBlockAttributes(table *models.Table) string { var sb strings.Builder // @@id for composite primary key pkCols := make([]string, 0) for _, col := range table.Columns { if col.IsPrimaryKey { pkCols = append(pkCols, col.Name) } } if len(pkCols) > 1 { sort.Strings(pkCols) sb.WriteString(fmt.Sprintf(" @@id([%s])\n", strings.Join(pkCols, ", "))) } // @@unique for multi-column unique constraints for _, constraint := range table.Constraints { if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 { sb.WriteString(fmt.Sprintf(" @@unique([%s])\n", strings.Join(constraint.Columns, ", "))) } } // @@index for indexes for _, index := range table.Indexes { if !index.Unique { // Unique indexes are handled by @@unique sb.WriteString(fmt.Sprintf(" @@index([%s])\n", strings.Join(index.Columns, ", "))) } } return sb.String() }