diff --git a/cmd/relspec/convert.go b/cmd/relspec/convert.go index 9be6658..c75d1ff 100644 --- a/cmd/relspec/convert.go +++ b/cmd/relspec/convert.go @@ -286,79 +286,79 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa if filePath == "" { return nil, fmt.Errorf("file path is required for DBML format") } - reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dbml.NewReader(newReaderOptions(filePath, "")) case "dctx": if filePath == "" { return nil, fmt.Errorf("file path is required for DCTX format") } - reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dctx.NewReader(newReaderOptions(filePath, "")) case "drawdb": if filePath == "" { return nil, fmt.Errorf("file path is required for DrawDB format") } - reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drawdb.NewReader(newReaderOptions(filePath, "")) case "json": if filePath == "" { return nil, fmt.Errorf("file path is required for JSON format") } - reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = json.NewReader(newReaderOptions(filePath, "")) case "yaml", "yml": if filePath == "" { return nil, fmt.Errorf("file path is required for YAML format") } - reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = yaml.NewReader(newReaderOptions(filePath, "")) case "pgsql", "postgres", "postgresql": if connString == "" { return nil, fmt.Errorf("connection string is required for PostgreSQL format") } - reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) + reader = pgsql.NewReader(newReaderOptions("", connString)) case "gorm": if filePath == "" { return nil, fmt.Errorf("file path is required for GORM format") } - reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = gorm.NewReader(newReaderOptions(filePath, "")) case "bun": if filePath == "" { return nil, fmt.Errorf("file path is required for Bun format") } - reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = bun.NewReader(newReaderOptions(filePath, "")) case "drizzle": if filePath == "" { return nil, fmt.Errorf("file path is required for Drizzle format") } - reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drizzle.NewReader(newReaderOptions(filePath, "")) case "prisma": if filePath == "" { return nil, fmt.Errorf("file path is required for Prisma format") } - reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = prisma.NewReader(newReaderOptions(filePath, "")) case "typeorm": if filePath == "" { return nil, fmt.Errorf("file path is required for TypeORM format") } - reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = typeorm.NewReader(newReaderOptions(filePath, "")) case "graphql", "gql": if filePath == "" { return nil, fmt.Errorf("file path is required for GraphQL format") } - reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = graphql.NewReader(newReaderOptions(filePath, "")) case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022": if connString == "" { return nil, fmt.Errorf("connection string is required for MSSQL format") } - reader = mssql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) + reader = mssql.NewReader(newReaderOptions("", connString)) case "sqlite", "sqlite3": // SQLite can use either file path or connection string @@ -369,7 +369,7 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa if dbPath == "" { return nil, fmt.Errorf("file path or connection string is required for SQLite format") } - reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) + reader = sqlite.NewReader(newReaderOptions(dbPath, "")) default: return nil, fmt.Errorf("unsupported source format: %s", dbType) @@ -386,12 +386,7 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error { var writer writers.Writer - writerOpts := &writers.WriterOptions{ - OutputPath: outputPath, - PackageName: packageName, - FlattenSchema: flattenSchema, - NullableTypes: nullableTypes, - } + writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes) switch strings.ToLower(dbType) { case "dbml": diff --git a/cmd/relspec/edit.go b/cmd/relspec/edit.go index 8e5c7e4..fb27325 100644 --- a/cmd/relspec/edit.go +++ b/cmd/relspec/edit.go @@ -240,62 +240,62 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da if filePath == "" { return nil, fmt.Errorf("%s: file path is required for DBML format", label) } - reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dbml.NewReader(newReaderOptions(filePath, "")) case "dctx": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for DCTX format", label) } - reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dctx.NewReader(newReaderOptions(filePath, "")) case "drawdb": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for DrawDB format", label) } - reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drawdb.NewReader(newReaderOptions(filePath, "")) case "graphql": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for GraphQL format", label) } - reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = graphql.NewReader(newReaderOptions(filePath, "")) case "json": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for JSON format", label) } - reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = json.NewReader(newReaderOptions(filePath, "")) case "yaml": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for YAML format", label) } - reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = yaml.NewReader(newReaderOptions(filePath, "")) case "gorm": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for GORM format", label) } - reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = gorm.NewReader(newReaderOptions(filePath, "")) case "bun": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for Bun format", label) } - reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = bun.NewReader(newReaderOptions(filePath, "")) case "drizzle": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for Drizzle format", label) } - reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drizzle.NewReader(newReaderOptions(filePath, "")) case "prisma": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for Prisma format", label) } - reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = prisma.NewReader(newReaderOptions(filePath, "")) case "typeorm": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for TypeORM format", label) } - reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = typeorm.NewReader(newReaderOptions(filePath, "")) case "pgsql": if connString == "" { return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label) } - reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) + reader = pgsql.NewReader(newReaderOptions("", connString)) case "sqlite", "sqlite3": // SQLite can use either file path or connection string dbPath := filePath @@ -305,7 +305,7 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da if dbPath == "" { return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label) } - reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) + reader = sqlite.NewReader(newReaderOptions(dbPath, "")) default: return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType) } @@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa switch strings.ToLower(dbType) { case "dbml": - writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, "")) case "dctx": - writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, "")) case "drawdb": - writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, "")) case "graphql": - writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, "")) case "json": - writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wjson.NewWriter(newWriterOptions(filePath, "", false, "")) case "yaml": - writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, "")) case "gorm": - writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, "")) case "bun": - writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wbun.NewWriter(newWriterOptions(filePath, "", false, "")) case "drizzle": - writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, "")) case "prisma": - writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, "")) case "typeorm": - writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, "")) case "sqlite", "sqlite3": - writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, "")) case "pgsql": - writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath}) + writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, "")) default: return fmt.Errorf("%s: unsupported format: %s", label, dbType) } diff --git a/cmd/relspec/inspect.go b/cmd/relspec/inspect.go index c9d95cf..bfccb78 100644 --- a/cmd/relspec/inspect.go +++ b/cmd/relspec/inspect.go @@ -221,73 +221,73 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa if filePath == "" { return nil, fmt.Errorf("file path is required for DBML format") } - reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dbml.NewReader(newReaderOptions(filePath, "")) case "dctx": if filePath == "" { return nil, fmt.Errorf("file path is required for DCTX format") } - reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dctx.NewReader(newReaderOptions(filePath, "")) case "drawdb": if filePath == "" { return nil, fmt.Errorf("file path is required for DrawDB format") } - reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drawdb.NewReader(newReaderOptions(filePath, "")) case "graphql": if filePath == "" { return nil, fmt.Errorf("file path is required for GraphQL format") } - reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = graphql.NewReader(newReaderOptions(filePath, "")) case "json": if filePath == "" { return nil, fmt.Errorf("file path is required for JSON format") } - reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = json.NewReader(newReaderOptions(filePath, "")) case "yaml", "yml": if filePath == "" { return nil, fmt.Errorf("file path is required for YAML format") } - reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = yaml.NewReader(newReaderOptions(filePath, "")) case "gorm": if filePath == "" { return nil, fmt.Errorf("file path is required for GORM format") } - reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = gorm.NewReader(newReaderOptions(filePath, "")) case "bun": if filePath == "" { return nil, fmt.Errorf("file path is required for Bun format") } - reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = bun.NewReader(newReaderOptions(filePath, "")) case "drizzle": if filePath == "" { return nil, fmt.Errorf("file path is required for Drizzle format") } - reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drizzle.NewReader(newReaderOptions(filePath, "")) case "prisma": if filePath == "" { return nil, fmt.Errorf("file path is required for Prisma format") } - reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = prisma.NewReader(newReaderOptions(filePath, "")) case "typeorm": if filePath == "" { return nil, fmt.Errorf("file path is required for TypeORM format") } - reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = typeorm.NewReader(newReaderOptions(filePath, "")) case "pgsql", "postgres", "postgresql": if connString == "" { return nil, fmt.Errorf("connection string is required for PostgreSQL format") } - reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) + reader = pgsql.NewReader(newReaderOptions("", connString)) case "sqlite", "sqlite3": // SQLite can use either file path or connection string @@ -298,7 +298,7 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa if dbPath == "" { return nil, fmt.Errorf("file path or connection string is required for SQLite format") } - reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) + reader = sqlite.NewReader(newReaderOptions(dbPath, "")) default: return nil, fmt.Errorf("unsupported database type: %s", dbType) diff --git a/cmd/relspec/merge.go b/cmd/relspec/merge.go index 071b0a3..55281e3 100644 --- a/cmd/relspec/merge.go +++ b/cmd/relspec/merge.go @@ -284,62 +284,62 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D if filePath == "" { return nil, fmt.Errorf("%s: file path is required for DBML format", label) } - reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dbml.NewReader(newReaderOptions(filePath, "")) case "dctx": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for DCTX format", label) } - reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = dctx.NewReader(newReaderOptions(filePath, "")) case "drawdb": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for DrawDB format", label) } - reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drawdb.NewReader(newReaderOptions(filePath, "")) case "graphql": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for GraphQL format", label) } - reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = graphql.NewReader(newReaderOptions(filePath, "")) case "json": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for JSON format", label) } - reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = json.NewReader(newReaderOptions(filePath, "")) case "yaml": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for YAML format", label) } - reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = yaml.NewReader(newReaderOptions(filePath, "")) case "gorm": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for GORM format", label) } - reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = gorm.NewReader(newReaderOptions(filePath, "")) case "bun": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for Bun format", label) } - reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = bun.NewReader(newReaderOptions(filePath, "")) case "drizzle": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for Drizzle format", label) } - reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = drizzle.NewReader(newReaderOptions(filePath, "")) case "prisma": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for Prisma format", label) } - reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = prisma.NewReader(newReaderOptions(filePath, "")) case "typeorm": if filePath == "" { return nil, fmt.Errorf("%s: file path is required for TypeORM format", label) } - reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + reader = typeorm.NewReader(newReaderOptions(filePath, "")) case "pgsql": if connString == "" { return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label) } - reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) + reader = pgsql.NewReader(newReaderOptions("", connString)) case "sqlite", "sqlite3": // SQLite can use either file path or connection string dbPath := filePath @@ -349,7 +349,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D if dbPath == "" { return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label) } - reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) + reader = sqlite.NewReader(newReaderOptions(dbPath, "")) default: return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType) } @@ -370,61 +370,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab if filePath == "" { return fmt.Errorf("%s: file path is required for DBML format", label) } - writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "dctx": if filePath == "" { return fmt.Errorf("%s: file path is required for DCTX format", label) } - writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "drawdb": if filePath == "" { return fmt.Errorf("%s: file path is required for DrawDB format", label) } - writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "graphql": if filePath == "" { return fmt.Errorf("%s: file path is required for GraphQL format", label) } - writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "json": if filePath == "" { return fmt.Errorf("%s: file path is required for JSON format", label) } - writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "yaml": if filePath == "" { return fmt.Errorf("%s: file path is required for YAML format", label) } - writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "gorm": if filePath == "" { return fmt.Errorf("%s: file path is required for GORM format", label) } - writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "bun": if filePath == "" { return fmt.Errorf("%s: file path is required for Bun format", label) } - writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "drizzle": if filePath == "" { return fmt.Errorf("%s: file path is required for Drizzle format", label) } - writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "prisma": if filePath == "" { return fmt.Errorf("%s: file path is required for Prisma format", label) } - writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "typeorm": if filePath == "" { return fmt.Errorf("%s: file path is required for TypeORM format", label) } - writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "sqlite", "sqlite3": - writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) + writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, "")) case "pgsql": - writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema} + writerOpts := newWriterOptions(filePath, "", flattenSchema, "") if connString != "" { writerOpts.Metadata = map[string]interface{}{ "connection_string": connString, diff --git a/cmd/relspec/prisma_options.go b/cmd/relspec/prisma_options.go new file mode 100644 index 0000000..2675640 --- /dev/null +++ b/cmd/relspec/prisma_options.go @@ -0,0 +1,24 @@ +package main + +import ( + "git.warky.dev/wdevs/relspecgo/pkg/readers" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +func newReaderOptions(filePath, connString string) *readers.ReaderOptions { + return &readers.ReaderOptions{ + FilePath: filePath, + ConnectionString: connString, + Prisma7: prisma7, + } +} + +func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions { + return &writers.WriterOptions{ + OutputPath: outputPath, + PackageName: packageName, + FlattenSchema: flattenSchema, + NullableTypes: nullableTypes, + Prisma7: prisma7, + } +} diff --git a/cmd/relspec/root.go b/cmd/relspec/root.go index ad2e672..32ed5b4 100644 --- a/cmd/relspec/root.go +++ b/cmd/relspec/root.go @@ -12,6 +12,7 @@ var ( // Version information, set via ldflags during build version = "dev" buildDate = "unknown" + prisma7 bool ) func init() { @@ -68,4 +69,5 @@ func init() { rootCmd.AddCommand(mergeCmd) rootCmd.AddCommand(splitCmd) rootCmd.AddCommand(versionCmd) + rootCmd.PersistentFlags().BoolVar(&prisma7, "prisma7", false, "Use Prisma 7 generator conventions when reading/writing Prisma schemas") } diff --git a/pkg/readers/prisma/reader.go b/pkg/readers/prisma/reader.go index 788f8e2..815329e 100644 --- a/pkg/readers/prisma/reader.go +++ b/pkg/readers/prisma/reader.go @@ -70,6 +70,7 @@ func (r *Reader) ReadTable() (*models.Table, error) { // parsePrisma parses Prisma schema content and returns a Database model func (r *Reader) parsePrisma(content string) (*models.Database, error) { db := models.InitDatabase("database") + db.SourceFormat = "prisma" if r.options.Metadata != nil { if name, ok := r.options.Metadata["name"].(string); ok { @@ -139,7 +140,7 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) { case "datasource": r.parseDatasource(blockContent, db) case "generator": - // We don't need to do anything with generator blocks + r.parseGenerator(blockContent, db) case "model": if currentTable != nil { r.parseModelFields(blockContent, currentTable) @@ -173,10 +174,34 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) { // Second pass: resolve relationships r.resolveRelationships(schema) + if db.SourceFormat == "prisma" && r.options != nil && r.options.Prisma7 { + db.SourceFormat = "prisma7" + } + db.Schemas = append(db.Schemas, schema) return db, nil } +func (r *Reader) parseGenerator(lines []string, db *models.Database) { + providerRegex := regexp.MustCompile(`provider\s*=\s*"([^"]+)"`) + + for _, line := range lines { + if matches := providerRegex.FindStringSubmatch(line); matches != nil { + switch matches[1] { + case "prisma-client": + db.SourceFormat = "prisma7" + default: + db.SourceFormat = "prisma" + } + return + } + } + + if r.options != nil && r.options.Prisma7 { + db.SourceFormat = "prisma7" + } +} + // parseDatasource extracts database type from datasource block func (r *Reader) parseDatasource(lines []string, db *models.Database) { providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`) diff --git a/pkg/readers/prisma/reader_test.go b/pkg/readers/prisma/reader_test.go new file mode 100644 index 0000000..d381d1c --- /dev/null +++ b/pkg/readers/prisma/reader_test.go @@ -0,0 +1,77 @@ +package prisma + +import ( + "os" + "path/filepath" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/readers" +) + +func TestReadDatabase_Prisma7GeneratorSetsSourceFormat(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + schemaPath := filepath.Join(tmpDir, "schema.prisma") + + content := `datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +generator client { + provider = "prisma-client" + output = "./generated" +} + +model User { + id Int @id @default(autoincrement()) +}` + + if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil { + t.Fatalf("failed to write schema: %v", err) + } + + reader := NewReader(&readers.ReaderOptions{FilePath: schemaPath}) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() failed: %v", err) + } + + if db.SourceFormat != "prisma7" { + t.Fatalf("expected SourceFormat prisma7, got %q", db.SourceFormat) + } +} + +func TestReadDatabase_Prisma7FlagSetsSourceFormatWithoutGenerator(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + schemaPath := filepath.Join(tmpDir, "schema.prisma") + + content := `datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model User { + id Int @id @default(autoincrement()) +}` + + if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil { + t.Fatalf("failed to write schema: %v", err) + } + + reader := NewReader(&readers.ReaderOptions{ + FilePath: schemaPath, + Prisma7: true, + }) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() failed: %v", err) + } + + if db.SourceFormat != "prisma7" { + t.Fatalf("expected SourceFormat prisma7 from flag, got %q", db.SourceFormat) + } +} diff --git a/pkg/readers/reader.go b/pkg/readers/reader.go index f8e5d16..ceee81f 100644 --- a/pkg/readers/reader.go +++ b/pkg/readers/reader.go @@ -25,6 +25,9 @@ type ReaderOptions struct { // ConnectionString is the database connection string (for DB readers) ConnectionString string + // Prisma7 enables Prisma 7-specific handling for Prisma schemas. + Prisma7 bool + // Additional options can be added here as needed Metadata map[string]interface{} } diff --git a/pkg/writers/prisma/writer.go b/pkg/writers/prisma/writer.go index ce253a4..1123e7d 100644 --- a/pkg/writers/prisma/writer.go +++ b/pkg/writers/prisma/writer.go @@ -61,7 +61,7 @@ func (w *Writer) databaseToPrisma(db *models.Database) string { sb.WriteString("\n") // Write generator block - sb.WriteString(w.generateGenerator()) + sb.WriteString(w.generateGenerator(db)) sb.WriteString("\n") // Process all schemas (typically just one in Prisma) @@ -114,13 +114,28 @@ func (w *Writer) generateDatasource(db *models.Database) string { } // generateGenerator generates the generator block -func (w *Writer) generateGenerator() string { +func (w *Writer) generateGenerator(db *models.Database) string { + if w.usePrisma7Generator(db) { + return `generator client { + provider = "prisma-client" + output = "./generated" +} +` + } + return `generator client { provider = "prisma-client-js" } ` } +func (w *Writer) usePrisma7Generator(db *models.Database) bool { + if w.options != nil && w.options.Prisma7 { + return true + } + return db != nil && db.SourceFormat == "prisma7" +} + // enumToPrisma converts an Enum to Prisma enum block func (w *Writer) enumToPrisma(enum *models.Enum) string { var sb strings.Builder diff --git a/pkg/writers/prisma/writer_test.go b/pkg/writers/prisma/writer_test.go new file mode 100644 index 0000000..a43b6b4 --- /dev/null +++ b/pkg/writers/prisma/writer_test.go @@ -0,0 +1,52 @@ +package prisma + +import ( + "strings" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/writers" +) + +func TestGenerateGenerator_DefaultsToPrismaClientJS(t *testing.T) { + t.Parallel() + + writer := NewWriter(&writers.WriterOptions{}) + db := models.InitDatabase("testdb") + + got := writer.generateGenerator(db) + if !strings.Contains(got, `provider = "prisma-client-js"`) { + t.Fatalf("expected prisma-client-js generator, got:\n%s", got) + } + if strings.Contains(got, `output = "./generated"`) { + t.Fatalf("did not expect prisma7 output path in default generator:\n%s", got) + } +} + +func TestGenerateGenerator_Prisma7FlagUsesPrismaClient(t *testing.T) { + t.Parallel() + + writer := NewWriter(&writers.WriterOptions{Prisma7: true}) + db := models.InitDatabase("testdb") + + got := writer.generateGenerator(db) + if !strings.Contains(got, `provider = "prisma-client"`) { + t.Fatalf("expected prisma-client generator, got:\n%s", got) + } + if !strings.Contains(got, `output = "./generated"`) { + t.Fatalf("expected prisma7 output path, got:\n%s", got) + } +} + +func TestGenerateGenerator_Prisma7SourceFormatUsesPrismaClient(t *testing.T) { + t.Parallel() + + writer := NewWriter(&writers.WriterOptions{}) + db := models.InitDatabase("testdb") + db.SourceFormat = "prisma7" + + got := writer.generateGenerator(db) + if !strings.Contains(got, `provider = "prisma-client"`) { + t.Fatalf("expected prisma-client generator from source format, got:\n%s", got) + } +} diff --git a/pkg/writers/writer.go b/pkg/writers/writer.go index 0637ebf..137c4c6 100644 --- a/pkg/writers/writer.go +++ b/pkg/writers/writer.go @@ -51,6 +51,9 @@ type WriterOptions struct { // "stdlib" — database/sql (sql.NullString, sql.NullInt32, …) NullableTypes string + // Prisma7 enables Prisma 7-specific output for Prisma writers. + Prisma7 bool + // Additional options can be added here as needed Metadata map[string]interface{} }