Compare commits

...

13 Commits

Author SHA1 Message Date
53ff745d5d chore(release): update package version to 1.0.54
All checks were successful
Release / test (push) Successful in 31m47s
Release / release (push) Successful in 31m9s
Release / pkg-aur (push) Successful in 31m57s
Release / pkg-deb (push) Successful in 31m1s
Release / pkg-rpm (push) Successful in 29m27s
2026-05-05 11:12:49 +02:00
17bc8ed395 feat(migration): enhance primary key handling and add GIN index support in migration writer 2026-05-05 11:12:23 +02:00
a447b68b22 chore(release): update package version to 1.0.53
All checks were successful
Release / test (push) Successful in 31m55s
Release / release (push) Successful in 31m19s
Release / pkg-aur (push) Successful in 32m3s
Release / pkg-deb (push) Successful in 31m21s
Release / pkg-rpm (push) Successful in 28m4s
2026-05-05 10:48:27 +02:00
4303dcf59b Support typed primary key helpers in gorm and bun writers 2026-05-05 10:32:33 +02:00
e828d48798 chore(release): update package version to 1.0.52
All checks were successful
Release / test (push) Successful in 32m39s
Release / release (push) Successful in 32m1s
Release / pkg-deb (push) Successful in 32m9s
Release / pkg-aur (push) Successful in 31m37s
Release / pkg-rpm (push) Successful in 27m28s
2026-05-03 17:19:22 +02:00
6e470a9239 fix(type_mapper): adjust array tag handling in BuildBunTag 2026-05-03 17:18:58 +02:00
096815fe49 chore(release): update package version to 1.0.51
All checks were successful
Release / test (push) Successful in 32m30s
Release / release (push) Successful in 31m54s
Release / pkg-aur (push) Successful in 32m31s
Release / pkg-deb (push) Successful in 32m7s
Release / pkg-rpm (push) Successful in 30m36s
2026-05-03 16:11:13 +02:00
b8f60203cb fix(type_mapper): handle PostgreSQL array types in tags
* Update BuildBunTag to append "array" for array types
* Add tests for handling array types in TypeMapper
* Adjust regex in SanitizeStructTagValue to preserve array suffix
2026-05-03 16:11:01 +02:00
Hein
15763f60cc Fix GIN opclass handling for array columns 2026-04-30 20:35:06 +02:00
Hein
6d2884f5cf chore(release): update package version to 1.0.50
Some checks failed
Release / test (push) Successful in 32m41s
Release / release (push) Successful in 28m56s
Release / pkg-deb (push) Successful in 31m19s
Release / pkg-aur (push) Successful in 27m21s
Release / pkg-rpm (push) Failing after 26m24s
2026-04-30 20:23:29 +02:00
Hein
f192decff8 Add Prisma 7 flag support 2026-04-30 20:22:57 +02:00
Hein
8b906cf4a3 chore(release): update package version to 1.0.49
All checks were successful
Release / test (push) Successful in 32m39s
Release / release (push) Successful in 31m40s
Release / pkg-aur (push) Successful in 32m46s
Release / pkg-deb (push) Successful in 32m9s
Release / pkg-rpm (push) Successful in 29m53s
2026-04-30 18:16:28 +02:00
Hein
0a3966e6fc fix(pgsql): handle default values for array types in migrations
* update default value quoting logic for PostgreSQL
* add tests for array default value handling
2026-04-30 18:16:21 +02:00
30 changed files with 1134 additions and 167 deletions

View File

@@ -286,79 +286,79 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DBML format") return nil, fmt.Errorf("file path is required for DBML format")
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DCTX format") return nil, fmt.Errorf("file path is required for DCTX format")
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DrawDB format") return nil, fmt.Errorf("file path is required for DrawDB format")
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for JSON format") return nil, fmt.Errorf("file path is required for JSON format")
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml", "yml": case "yaml", "yml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for YAML format") return nil, fmt.Errorf("file path is required for YAML format")
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "pgsql", "postgres", "postgresql": case "pgsql", "postgres", "postgresql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("connection string is required for PostgreSQL format") return nil, fmt.Errorf("connection string is required for PostgreSQL format")
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GORM format") return nil, fmt.Errorf("file path is required for GORM format")
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Bun format") return nil, fmt.Errorf("file path is required for Bun format")
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Drizzle format") return nil, fmt.Errorf("file path is required for Drizzle format")
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Prisma format") return nil, fmt.Errorf("file path is required for Prisma format")
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for TypeORM format") return nil, fmt.Errorf("file path is required for TypeORM format")
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "graphql", "gql": case "graphql", "gql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GraphQL format") return nil, fmt.Errorf("file path is required for GraphQL format")
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022": case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("connection string is required for MSSQL format") return nil, fmt.Errorf("connection string is required for MSSQL format")
} }
reader = mssql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = mssql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
@@ -369,7 +369,7 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("file path or connection string is required for SQLite format") return nil, fmt.Errorf("file path or connection string is required for SQLite format")
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("unsupported source format: %s", dbType) return nil, fmt.Errorf("unsupported source format: %s", dbType)
@@ -386,12 +386,7 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error { func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error {
var writer writers.Writer var writer writers.Writer
writerOpts := &writers.WriterOptions{ writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes)
OutputPath: outputPath,
PackageName: packageName,
FlattenSchema: flattenSchema,
NullableTypes: nullableTypes,
}
switch strings.ToLower(dbType) { switch strings.ToLower(dbType) {
case "dbml": case "dbml":

View File

@@ -240,62 +240,62 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DBML format", label) return nil, fmt.Errorf("%s: file path is required for DBML format", label)
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DCTX format", label) return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label) return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label) return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for JSON format", label) return nil, fmt.Errorf("%s: file path is required for JSON format", label)
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml": case "yaml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for YAML format", label) return nil, fmt.Errorf("%s: file path is required for YAML format", label)
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GORM format", label) return nil, fmt.Errorf("%s: file path is required for GORM format", label)
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Bun format", label) return nil, fmt.Errorf("%s: file path is required for Bun format", label)
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label) return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Prisma format", label) return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label) return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "pgsql": case "pgsql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label) return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
dbPath := filePath dbPath := filePath
@@ -305,7 +305,7 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label) return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType) return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType)
} }
@@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa
switch strings.ToLower(dbType) { switch strings.ToLower(dbType) {
case "dbml": case "dbml":
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, ""))
case "dctx": case "dctx":
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, ""))
case "drawdb": case "drawdb":
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, ""))
case "graphql": case "graphql":
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, ""))
case "json": case "json":
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wjson.NewWriter(newWriterOptions(filePath, "", false, ""))
case "yaml": case "yaml":
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, ""))
case "gorm": case "gorm":
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, ""))
case "bun": case "bun":
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wbun.NewWriter(newWriterOptions(filePath, "", false, ""))
case "drizzle": case "drizzle":
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, ""))
case "prisma": case "prisma":
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, ""))
case "typeorm": case "typeorm":
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, ""))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, ""))
case "pgsql": case "pgsql":
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, ""))
default: default:
return fmt.Errorf("%s: unsupported format: %s", label, dbType) return fmt.Errorf("%s: unsupported format: %s", label, dbType)
} }

View File

@@ -221,73 +221,73 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DBML format") return nil, fmt.Errorf("file path is required for DBML format")
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DCTX format") return nil, fmt.Errorf("file path is required for DCTX format")
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DrawDB format") return nil, fmt.Errorf("file path is required for DrawDB format")
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GraphQL format") return nil, fmt.Errorf("file path is required for GraphQL format")
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for JSON format") return nil, fmt.Errorf("file path is required for JSON format")
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml", "yml": case "yaml", "yml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for YAML format") return nil, fmt.Errorf("file path is required for YAML format")
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GORM format") return nil, fmt.Errorf("file path is required for GORM format")
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Bun format") return nil, fmt.Errorf("file path is required for Bun format")
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Drizzle format") return nil, fmt.Errorf("file path is required for Drizzle format")
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Prisma format") return nil, fmt.Errorf("file path is required for Prisma format")
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for TypeORM format") return nil, fmt.Errorf("file path is required for TypeORM format")
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "pgsql", "postgres", "postgresql": case "pgsql", "postgres", "postgresql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("connection string is required for PostgreSQL format") return nil, fmt.Errorf("connection string is required for PostgreSQL format")
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
@@ -298,7 +298,7 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("file path or connection string is required for SQLite format") return nil, fmt.Errorf("file path or connection string is required for SQLite format")
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("unsupported database type: %s", dbType) return nil, fmt.Errorf("unsupported database type: %s", dbType)

View File

@@ -284,62 +284,62 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DBML format", label) return nil, fmt.Errorf("%s: file path is required for DBML format", label)
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DCTX format", label) return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label) return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label) return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for JSON format", label) return nil, fmt.Errorf("%s: file path is required for JSON format", label)
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml": case "yaml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for YAML format", label) return nil, fmt.Errorf("%s: file path is required for YAML format", label)
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GORM format", label) return nil, fmt.Errorf("%s: file path is required for GORM format", label)
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Bun format", label) return nil, fmt.Errorf("%s: file path is required for Bun format", label)
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label) return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Prisma format", label) return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label) return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "pgsql": case "pgsql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label) return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
dbPath := filePath dbPath := filePath
@@ -349,7 +349,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label) return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType) return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
} }
@@ -370,61 +370,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for DBML format", label) return fmt.Errorf("%s: file path is required for DBML format", label)
} }
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for DCTX format", label) return fmt.Errorf("%s: file path is required for DCTX format", label)
} }
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for DrawDB format", label) return fmt.Errorf("%s: file path is required for DrawDB format", label)
} }
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for GraphQL format", label) return fmt.Errorf("%s: file path is required for GraphQL format", label)
} }
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for JSON format", label) return fmt.Errorf("%s: file path is required for JSON format", label)
} }
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "yaml": case "yaml":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for YAML format", label) return fmt.Errorf("%s: file path is required for YAML format", label)
} }
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for GORM format", label) return fmt.Errorf("%s: file path is required for GORM format", label)
} }
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for Bun format", label) return fmt.Errorf("%s: file path is required for Bun format", label)
} }
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for Drizzle format", label) return fmt.Errorf("%s: file path is required for Drizzle format", label)
} }
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for Prisma format", label) return fmt.Errorf("%s: file path is required for Prisma format", label)
} }
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for TypeORM format", label) return fmt.Errorf("%s: file path is required for TypeORM format", label)
} }
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "pgsql": case "pgsql":
writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema} writerOpts := newWriterOptions(filePath, "", flattenSchema, "")
if connString != "" { if connString != "" {
writerOpts.Metadata = map[string]interface{}{ writerOpts.Metadata = map[string]interface{}{
"connection_string": connString, "connection_string": connString,

View File

@@ -0,0 +1,24 @@
package main
import (
"git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// newReaderOptions builds ReaderOptions for the given input file path and
// connection string, threading the package-level --prisma7 flag through so
// Prisma readers can apply Prisma 7 conventions.
func newReaderOptions(filePath, connString string) *readers.ReaderOptions {
	return &readers.ReaderOptions{
		FilePath:         filePath,
		ConnectionString: connString,
		Prisma7:          prisma7,
	}
}
// newWriterOptions builds WriterOptions with the common per-format settings,
// threading the package-level --prisma7 flag through so Prisma writers can
// apply Prisma 7 conventions. packageName and nullableTypes may be empty for
// formats that do not use them.
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions {
	return &writers.WriterOptions{
		OutputPath:    outputPath,
		PackageName:   packageName,
		FlattenSchema: flattenSchema,
		NullableTypes: nullableTypes,
		Prisma7:       prisma7,
	}
}

View File

@@ -12,6 +12,7 @@ var (
// Version information, set via ldflags during build // Version information, set via ldflags during build
version = "dev" version = "dev"
buildDate = "unknown" buildDate = "unknown"
prisma7 bool
) )
func init() { func init() {
@@ -68,4 +69,5 @@ func init() {
rootCmd.AddCommand(mergeCmd) rootCmd.AddCommand(mergeCmd)
rootCmd.AddCommand(splitCmd) rootCmd.AddCommand(splitCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.PersistentFlags().BoolVar(&prisma7, "prisma7", false, "Use Prisma 7 generator conventions when reading/writing Prisma schemas")
} }

View File

@@ -1,6 +1,6 @@
# Maintainer: Hein (Warky Devs) <hein@warky.dev> # Maintainer: Hein (Warky Devs) <hein@warky.dev>
pkgname=relspec pkgname=relspec
pkgver=1.0.48 pkgver=1.0.54
pkgrel=1 pkgrel=1
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs." pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
arch=('x86_64' 'aarch64') arch=('x86_64' 'aarch64')

View File

@@ -1,5 +1,5 @@
Name: relspec Name: relspec
Version: 1.0.48 Version: 1.0.54
Release: 1%{?dist} Release: 1%{?dist}
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs. Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.

View File

@@ -70,6 +70,7 @@ func (r *Reader) ReadTable() (*models.Table, error) {
// parsePrisma parses Prisma schema content and returns a Database model // parsePrisma parses Prisma schema content and returns a Database model
func (r *Reader) parsePrisma(content string) (*models.Database, error) { func (r *Reader) parsePrisma(content string) (*models.Database, error) {
db := models.InitDatabase("database") db := models.InitDatabase("database")
db.SourceFormat = "prisma"
if r.options.Metadata != nil { if r.options.Metadata != nil {
if name, ok := r.options.Metadata["name"].(string); ok { if name, ok := r.options.Metadata["name"].(string); ok {
@@ -139,7 +140,7 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
case "datasource": case "datasource":
r.parseDatasource(blockContent, db) r.parseDatasource(blockContent, db)
case "generator": case "generator":
// We don't need to do anything with generator blocks r.parseGenerator(blockContent, db)
case "model": case "model":
if currentTable != nil { if currentTable != nil {
r.parseModelFields(blockContent, currentTable) r.parseModelFields(blockContent, currentTable)
@@ -173,10 +174,34 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
// Second pass: resolve relationships // Second pass: resolve relationships
r.resolveRelationships(schema) r.resolveRelationships(schema)
if db.SourceFormat == "prisma" && r.options != nil && r.options.Prisma7 {
db.SourceFormat = "prisma7"
}
db.Schemas = append(db.Schemas, schema) db.Schemas = append(db.Schemas, schema)
return db, nil return db, nil
} }
// parseGenerator classifies the schema's source format from a generator
// block: a provider of "prisma-client" marks the database as "prisma7", any
// other provider as "prisma". Only the first provider line found is used.
// When no provider line is present, the Prisma7 reader option (if set)
// forces "prisma7".
func (r *Reader) parseGenerator(lines []string, db *models.Database) {
	providerRegex := regexp.MustCompile(`provider\s*=\s*"([^"]+)"`)
	for _, line := range lines {
		if matches := providerRegex.FindStringSubmatch(line); matches != nil {
			switch matches[1] {
			case "prisma-client":
				// "prisma-client" is the provider used by Prisma 7 schemas.
				db.SourceFormat = "prisma7"
			default:
				db.SourceFormat = "prisma"
			}
			// First provider line wins; ignore the rest of the block.
			return
		}
	}
	// No provider declared in the block: fall back to the explicit flag.
	if r.options != nil && r.options.Prisma7 {
		db.SourceFormat = "prisma7"
	}
}
// parseDatasource extracts database type from datasource block // parseDatasource extracts database type from datasource block
func (r *Reader) parseDatasource(lines []string, db *models.Database) { func (r *Reader) parseDatasource(lines []string, db *models.Database) {
providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`) providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)

View File

@@ -0,0 +1,77 @@
package prisma
import (
"os"
"path/filepath"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// TestReadDatabase_Prisma7GeneratorSetsSourceFormat verifies that a schema
// whose generator block declares provider "prisma-client" is detected as
// Prisma 7 (Database.SourceFormat == "prisma7") without the Prisma7 reader
// option being set.
func TestReadDatabase_Prisma7GeneratorSetsSourceFormat(t *testing.T) {
	t.Parallel()
	tmpDir := t.TempDir()
	schemaPath := filepath.Join(tmpDir, "schema.prisma")
	// Minimal schema: datasource plus a generator with the Prisma 7 provider.
	content := `datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client"
output = "./generated"
}
model User {
id Int @id @default(autoincrement())
}`
	if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil {
		t.Fatalf("failed to write schema: %v", err)
	}
	reader := NewReader(&readers.ReaderOptions{FilePath: schemaPath})
	db, err := reader.ReadDatabase()
	if err != nil {
		t.Fatalf("ReadDatabase() failed: %v", err)
	}
	if db.SourceFormat != "prisma7" {
		t.Fatalf("expected SourceFormat prisma7, got %q", db.SourceFormat)
	}
}
// TestReadDatabase_Prisma7FlagSetsSourceFormatWithoutGenerator verifies that
// the Prisma7 reader option alone forces SourceFormat "prisma7" when the
// schema has no generator block at all.
func TestReadDatabase_Prisma7FlagSetsSourceFormatWithoutGenerator(t *testing.T) {
	t.Parallel()
	tmpDir := t.TempDir()
	schemaPath := filepath.Join(tmpDir, "schema.prisma")
	// Schema deliberately omits a generator block.
	content := `datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}
model User {
id Int @id @default(autoincrement())
}`
	if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil {
		t.Fatalf("failed to write schema: %v", err)
	}
	reader := NewReader(&readers.ReaderOptions{
		FilePath: schemaPath,
		Prisma7:  true,
	})
	db, err := reader.ReadDatabase()
	if err != nil {
		t.Fatalf("ReadDatabase() failed: %v", err)
	}
	if db.SourceFormat != "prisma7" {
		t.Fatalf("expected SourceFormat prisma7 from flag, got %q", db.SourceFormat)
	}
}

View File

@@ -25,6 +25,9 @@ type ReaderOptions struct {
// ConnectionString is the database connection string (for DB readers) // ConnectionString is the database connection string (for DB readers)
ConnectionString string ConnectionString string
// Prisma7 enables Prisma 7-specific handling for Prisma schemas.
Prisma7 bool
// Additional options can be added here as needed // Additional options can be added here as needed
Metadata map[string]interface{} Metadata map[string]interface{}
} }

View File

@@ -18,17 +18,20 @@ type TemplateData struct {
// ModelData represents a single model/struct in the template // ModelData represents a single model/struct in the template
type ModelData struct { type ModelData struct {
Name string Name string
TableName string // schema.table format TableName string // schema.table format
SchemaName string SchemaName string
TableNameOnly string // just table name without schema TableNameOnly string // just table name without schema
Comment string Comment string
Fields []*FieldData Fields []*FieldData
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call) PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
Prefix string // 3-letter prefix PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
} }
// FieldData represents a single field in a struct // FieldData represents a single field in a struct
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
model.IDColumnName = safeName model.IDColumnName = safeName
// Check if PK type is a SQL type (contains resolvespec_common or sql_types) // Check if PK type is a SQL type (contains resolvespec_common or sql_types)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types") model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
break break
} }
} }
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
return comment return comment
} }
// isStringLikePrimaryKeyType reports whether goType is a string-backed
// primary key type, meaning the generated GetID/SetID/UpdateID helpers
// should operate on string IDs rather than int64.
func isStringLikePrimaryKeyType(goType string) bool {
	stringLikeTypes := []string{
		"string",
		"sql.NullString",
		"resolvespec_common.SqlString",
		"resolvespec_common.SqlUUID",
	}
	for _, candidate := range stringLikeTypes {
		if goType == candidate {
			return true
		}
	}
	return false
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names // resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision // and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string { func resolveFieldNameCollision(fieldName string) string {

View File

@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
{{end}} {{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}} {{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value // GetID returns the primary key value
func (m {{.Name}}) GetID() int64 { func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}} {{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64() return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}} {{- else -}}
return int64(m.{{.PrimaryKeyField}}) return int64(m.{{.PrimaryKeyField}})
{{- end}} {{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string // GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string { func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}} {{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value // SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) { func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid) m.UpdateID(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}} {{if .PrimaryKeyIsSQL -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid)) {{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}} {{- else -}}
m.{{.PrimaryKeyField}} = int32(newid) m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}} {{- end}}
} }
{{end}} {{end}}

View File

@@ -311,10 +311,11 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
if column.Type != "" { if column.Type != "" {
// Sanitize type to remove backticks // Sanitize type to remove backticks
typeStr := writers.SanitizeStructTagValue(column.Type) typeStr := writers.SanitizeStructTagValue(column.Type)
isArray := pgsql.IsArrayType(typeStr)
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr) hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
if !hasExplicitTypeModifier && column.Length > 0 { if !hasExplicitTypeModifier && !isArray && column.Length > 0 {
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
} else if !hasExplicitTypeModifier && column.Precision > 0 { } else if !hasExplicitTypeModifier && !isArray && column.Precision > 0 {
if column.Scale > 0 { if column.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale) typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
} else { } else {
@@ -322,6 +323,9 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
} }
} }
parts = append(parts, fmt.Sprintf("type:%s", typeStr)) parts = append(parts, fmt.Sprintf("type:%s", typeStr))
if isArray && tm.typeStyle == writers.NullableTypeStdlib {
parts = append(parts, "array")
}
} }
// Primary key // Primary key

View File

@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
} }
// needsFmtImport reports whether any generated model's helper methods will
// emit fmt.Sprintf calls, which require the "fmt" import:
//   - GetIDStr on a plain numeric primary key formats it with %d;
//   - UpdateID on a SQL wrapper (non-string) primary key formats the int64
//     before passing it to FromString.
// String-like primary keys never need fmt in either helper.
func (w *Writer) needsFmtImport(models []*ModelData) bool {
	if w.config.GenerateGetIDStr {
		for _, model := range models {
			if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
				return true
			}
		}
	}
	if w.config.GenerateUpdateID {
		for _, model := range models {
			if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
				return true
			}
		}
	}
	return false
}
// findTable finds a table by schema and name in the database // findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas { for _, schema := range db.Schemas {

View File

@@ -574,6 +574,10 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
{"boolean", false, "resolvespec_common.SqlBool"}, {"boolean", false, "resolvespec_common.SqlBool"},
{"uuid", false, "resolvespec_common.SqlUUID"}, {"uuid", false, "resolvespec_common.SqlUUID"},
{"jsonb", false, "resolvespec_common.SqlJSONB"}, {"jsonb", false, "resolvespec_common.SqlJSONB"},
{"text[]", true, "resolvespec_common.SqlStringArray"},
{"text[]", false, "resolvespec_common.SqlStringArray"},
{"integer[]", true, "resolvespec_common.SqlInt32Array"},
{"bigint[]", false, "resolvespec_common.SqlInt64Array"},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -586,6 +590,116 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
} }
} }
// TestWriter_UpdateIDTypeSafety_Bun checks that the generated UpdateID helper
// converts the incoming int64 to the model's actual primary-key Go type
// (int32 cast, int64 cast, or FromString on a resolvespec SQL wrapper)
// instead of unconditionally casting to int32.
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
	tests := []struct {
		name         string
		pkType       string // SQL column type of the primary key
		expectedPK   string // expected Go type of the generated PK field
		expectedLine string // expected assignment/call inside UpdateID
		forbidInt32  bool   // int32(newid) must NOT appear for this type
	}{
		{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
		{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
		{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			table := models.InitTable("test_table", "public")
			table.Columns["id"] = &models.Column{
				Name:         "id",
				Type:         tt.pkType,
				NotNull:      true,
				IsPrimaryKey: true,
			}
			tmpDir := t.TempDir()
			opts := &writers.WriterOptions{
				PackageName: "models",
				OutputPath:  filepath.Join(tmpDir, "test.go"),
			}
			writer := NewWriter(opts)
			err := writer.WriteTable(table)
			if err != nil {
				t.Fatalf("WriteTable failed: %v", err)
			}
			content, err := os.ReadFile(opts.OutputPath)
			if err != nil {
				t.Fatalf("Failed to read generated file: %v", err)
			}
			generated := string(content)
			if !strings.Contains(generated, tt.expectedLine) {
				t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
			}
			if !strings.Contains(generated, "ID "+tt.expectedPK) {
				t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
			}
			if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
				t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
			}
			if !strings.Contains(generated, "UpdateID(newid int64)") {
				t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
			}
		})
	}
}
// TestWriter_StringPrimaryKeyHelpers_Bun checks that a uuid primary key maps
// to resolvespec_common.SqlUUID and produces string-typed GetID/SetID/
// UpdateID helper signatures (never the int64 variants).
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
	table := models.InitTable("accounts", "public")
	table.Columns["id"] = &models.Column{
		Name:         "id",
		Type:         "uuid",
		NotNull:      true,
		IsPrimaryKey: true,
	}
	tmpDir := t.TempDir()
	opts := &writers.WriterOptions{
		PackageName: "models",
		OutputPath:  filepath.Join(tmpDir, "test.go"),
	}
	writer := NewWriter(opts)
	err := writer.WriteTable(table)
	if err != nil {
		t.Fatalf("WriteTable failed: %v", err)
	}
	content, err := os.ReadFile(opts.OutputPath)
	if err != nil {
		t.Fatalf("Failed to read generated file: %v", err)
	}
	generated := string(content)
	// Each snippet must appear verbatim in the generated model file.
	expectations := []string{
		"resolvespec_common.SqlUUID",
		"func (m ModelPublicAccounts) GetID() string",
		"return m.ID.String()",
		"func (m ModelPublicAccounts) GetIDStr() string",
		"func (m ModelPublicAccounts) SetID(newid string)",
		"func (m *ModelPublicAccounts) UpdateID(newid string)",
		"m.ID.FromString(newid)",
	}
	for _, expected := range expectations {
		if !strings.Contains(generated, expected) {
			t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
		}
	}
	if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
		t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
	}
}
func TestTypeMapper_BuildBunTag(t *testing.T) { func TestTypeMapper_BuildBunTag(t *testing.T) {
mapper := NewTypeMapper("") mapper := NewTypeMapper("")
@@ -685,6 +799,24 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
}, },
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"}, want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
}, },
{
name: "text array type",
column: &models.Column{
Name: "tags",
Type: "text[]",
NotNull: false,
},
want: []string{"tags,", "type:text[],"},
},
{
name: "integer array type",
column: &models.Column{
Name: "scores",
Type: "integer[]",
NotNull: true,
},
want: []string{"scores,", "type:integer[],"},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -695,6 +827,30 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
t.Errorf("BuildBunTag() = %q, missing %q", result, part) t.Errorf("BuildBunTag() = %q, missing %q", result, part)
} }
} }
// resolvespec mode must NOT add "array" — SqlXxxArray uses sql.Scanner
if strings.Contains(result, ",array,") || strings.HasSuffix(result, ",array,") {
t.Errorf("BuildBunTag() = %q, must not contain 'array' in resolvespec mode", result)
}
})
}
}
func TestTypeMapper_BuildBunTag_StdlibArrayHasArrayTag(t *testing.T) {
mapper := NewTypeMapper(writers.NullableTypeStdlib)
cases := []struct {
name string
column *models.Column
}{
{name: "text array", column: &models.Column{Name: "tags", Type: "text[]"}},
{name: "integer array", column: &models.Column{Name: "scores", Type: "integer[]", NotNull: true}},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
result := mapper.BuildBunTag(tt.column, nil)
if !strings.Contains(result, "array") {
t.Errorf("BuildBunTag() = %q, expected 'array' in stdlib mode", result)
}
}) })
} }
} }

View File

@@ -2,6 +2,7 @@ package gorm
import ( import (
"sort" "sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
@@ -17,17 +18,20 @@ type TemplateData struct {
// ModelData represents a single model/struct in the template // ModelData represents a single model/struct in the template
type ModelData struct { type ModelData struct {
Name string Name string
TableName string // schema.table format TableName string // schema.table format
SchemaName string SchemaName string
TableNameOnly string // just table name without schema TableNameOnly string // just table name without schema
Comment string Comment string
Fields []*FieldData Fields []*FieldData
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
Prefix string // 3-letter prefix PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
} }
// FieldData represents a single field in a struct // FieldData represents a single field in a struct
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
// Sanitize column name to remove backticks // Sanitize column name to remove backticks
safeName := writers.SanitizeStructTagValue(col.Name) safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName) model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
model.IDColumnName = safeName model.IDColumnName = safeName
break break
} }
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
return comment return comment
} }
// isStringLikePrimaryKeyType reports whether goType is a string-backed
// primary key type, meaning the generated GetID/SetID/UpdateID helpers
// should operate on string IDs rather than int64.
func isStringLikePrimaryKeyType(goType string) bool {
	stringLikeTypes := []string{
		"string",
		"sql.NullString",
		"sql_types.SqlString",
		"sql_types.SqlUUID",
	}
	for _, candidate := range stringLikeTypes {
		if goType == candidate {
			return true
		}
	}
	return false
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names // resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision // and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string { func resolveFieldNameCollision(fieldName string) string {

View File

@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
{{end}} {{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}} {{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value // GetID returns the primary key value
func (m {{.Name}}) GetID() int64 { func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return int64(m.{{.PrimaryKeyField}}) return int64(m.{{.PrimaryKeyField}})
{{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string // GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string { func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}} {{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value // SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) { func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid) m.UpdateID(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid) m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}} {{if and .Config.GenerateGetIDName .IDColumnName}}

View File

@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
} }
// needsFmtImport reports whether any generated model's helper methods will
// emit fmt.Sprintf calls, which require the "fmt" import:
//   - GetIDStr on a plain numeric primary key formats it with %d;
//   - UpdateID on a SQL wrapper (non-string) primary key formats the int64
//     before passing it to FromString.
// String-like primary keys never need fmt in either helper.
func (w *Writer) needsFmtImport(models []*ModelData) bool {
	if w.config.GenerateGetIDStr {
		for _, model := range models {
			if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
				return true
			}
		}
	}
	if w.config.GenerateUpdateID {
		for _, model := range models {
			if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
				return true
			}
		}
	}
	return false
}
// findTable finds a table by schema and name in the database // findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas { for _, schema := range db.Schemas {

View File

@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
} }
} }
// TestWriter_StringPrimaryKeyHelpers_Gorm checks that a uuid primary key maps
// to a plain string field and produces string-typed GetID/SetID/UpdateID
// helper signatures (never the int64 variants).
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
	table := models.InitTable("accounts", "public")
	table.Columns["id"] = &models.Column{
		Name:         "id",
		Type:         "uuid",
		NotNull:      true,
		IsPrimaryKey: true,
	}
	tmpDir := t.TempDir()
	opts := &writers.WriterOptions{
		PackageName: "models",
		OutputPath:  filepath.Join(tmpDir, "test.go"),
	}
	writer := NewWriter(opts)
	err := writer.WriteTable(table)
	if err != nil {
		t.Fatalf("WriteTable failed: %v", err)
	}
	content, err := os.ReadFile(opts.OutputPath)
	if err != nil {
		t.Fatalf("Failed to read generated file: %v", err)
	}
	generated := string(content)
	// Each snippet must appear verbatim in the generated model file.
	expectations := []string{
		"ID string",
		"func (m ModelPublicAccounts) GetID() string",
		"return m.ID",
		"func (m ModelPublicAccounts) GetIDStr() string",
		"func (m ModelPublicAccounts) SetID(newid string)",
		"func (m *ModelPublicAccounts) UpdateID(newid string)",
		"m.ID = newid",
	}
	for _, expected := range expectations {
		if !strings.Contains(generated, expected) {
			t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
		}
	}
	if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
		t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
	}
}
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) { func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
@@ -658,6 +707,10 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
{"timestamp", false, "sql_types.SqlTimeStamp"}, {"timestamp", false, "sql_types.SqlTimeStamp"},
{"boolean", true, "bool"}, {"boolean", true, "bool"},
{"boolean", false, "sql_types.SqlBool"}, {"boolean", false, "sql_types.SqlBool"},
{"text[]", true, "sql_types.SqlStringArray"},
{"text[]", false, "sql_types.SqlStringArray"},
{"integer[]", true, "sql_types.SqlInt32Array"},
{"bigint[]", false, "sql_types.SqlInt64Array"},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -670,6 +723,21 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
} }
} }
// TestTypeMapper_BuildGormTag_ArrayType verifies that a PostgreSQL array
// column type (text[]) is preserved verbatim in the generated gorm
// "type:" tag rather than being mangled or given a length modifier.
func TestTypeMapper_BuildGormTag_ArrayType(t *testing.T) {
	mapper := NewTypeMapper("")
	col := &models.Column{
		Name:    "tags",
		Type:    "text[]",
		NotNull: false,
	}
	tag := mapper.BuildGormTag(col, nil)
	if !strings.Contains(tag, "type:text[]") {
		t.Fatalf("expected array type to be preserved, got %q", tag)
	}
}
func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) { func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) {
mapper := NewTypeMapper("") mapper := NewTypeMapper("")

View File

@@ -31,6 +31,10 @@ type MigrationWriter struct {
// NewMigrationWriter creates a new templated migration writer // NewMigrationWriter creates a new templated migration writer
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) { func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
if options == nil {
options = &writers.WriterOptions{}
}
executor, err := NewTemplateExecutor(options.FlattenSchema) executor, err := NewTemplateExecutor(options.FlattenSchema)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create template executor: %w", err) return nil, fmt.Errorf("failed to create template executor: %w", err)
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
// WriteMigration generates migration scripts using templates // WriteMigration generates migration scripts using templates
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error { func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
if model == nil {
return fmt.Errorf("model database is required")
}
if w.options == nil {
w.options = &writers.WriterOptions{}
}
if current == nil {
current = models.InitDatabase(model.Name)
}
var writer io.Writer var writer io.Writer
var file *os.File var file *os.File
var err error var err error
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
// Process each schema in the model // Process each schema in the model
for _, modelSchema := range model.Schemas { for _, modelSchema := range model.Schemas {
if modelSchema == nil {
continue
}
// Find corresponding schema in current database // Find corresponding schema in current database
var currentSchema *models.Schema var currentSchema *models.Schema
for _, cs := range current.Schemas { for _, cs := range current.Schemas {
if cs == nil {
continue
}
if strings.EqualFold(cs.Name, modelSchema.Name) { if strings.EqualFold(cs.Name, modelSchema.Name) {
currentSchema = cs currentSchema = cs
break break
@@ -329,7 +350,11 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
// Column doesn't exist, add it // Column doesn't exist, add it
defaultVal := "" defaultVal := ""
if modelCol.Default != nil { if modelCol.Default != nil {
defaultVal = fmt.Sprintf("%v", modelCol.Default) if value, ok := modelCol.Default.(string); ok {
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
} else {
defaultVal = fmt.Sprintf("%v", modelCol.Default)
}
} }
sql, err := w.executor.ExecuteAddColumn(AddColumnData{ sql, err := w.executor.ExecuteAddColumn(AddColumnData{
@@ -382,7 +407,11 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
setDefault := modelCol.Default != nil setDefault := modelCol.Default != nil
defaultVal := "" defaultVal := ""
if setDefault { if setDefault {
defaultVal = fmt.Sprintf("%v", modelCol.Default) if value, ok := modelCol.Default.(string); ok {
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
} else {
defaultVal = fmt.Sprintf("%v", modelCol.Default)
}
} }
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{ sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
@@ -537,12 +566,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
indexType = modelIndex.Type indexType = modelIndex.Type
} }
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
if len(columnExprs) == 0 {
continue
}
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{ sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
SchemaName: model.Name, SchemaName: model.Name,
TableName: modelTable.Name, TableName: modelTable.Name,
IndexName: indexName, IndexName: indexName,
IndexType: indexType, IndexType: indexType,
Columns: strings.Join(modelIndex.Columns, ", "), Columns: strings.Join(columnExprs, ", "),
Unique: modelIndex.Unique, Unique: modelIndex.Unique,
}) })
if err != nil { if err != nil {
@@ -565,6 +599,27 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
return scripts, nil return scripts, nil
} }
// buildIndexColumnExpressions resolves the column expressions for a
// CREATE INDEX statement.
//
// For each index column the SQL-safe name (col.SQLName()) is used when the
// column is found on the table; otherwise the raw name from the index
// definition is kept. For GIN indexes over scalar text columns an operator
// class is appended — one extracted from the index comment, or gin_trgm_ops
// as a fallback. Array columns such as text[] are excluded by isTextType,
// since GIN handles arrays natively without a trigram opclass.
//
// Returns nil when index is nil so callers can skip it safely (the caller
// already treats an empty result as "skip this index").
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
	if index == nil {
		return nil
	}
	columnExprs := make([]string, 0, len(index.Columns))
	for _, colName := range index.Columns {
		colExpr := colName
		if table != nil {
			if col, ok := table.Columns[colName]; ok && col != nil {
				colExpr = col.SQLName()
				if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
					opClass := extractOperatorClass(index.Comment)
					if opClass == "" {
						// Default to trigram matching, the common reason for a
						// GIN index on a plain text column.
						opClass = "gin_trgm_ops"
					}
					colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
				}
			}
		}
		columnExprs = append(columnExprs, colExpr)
	}
	return columnExprs
}
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates // generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0) scripts := make([]MigrationScript, 0)

View File

@@ -57,6 +57,129 @@ func TestWriteMigration_NewTable(t *testing.T) {
} }
} }
// TestWriteMigration_ArrayDefault ensures a doubled-quoted array default such
// as ''{}'' is normalized to a single-quoted literal in migration output.
func TestWriteMigration_ArrayDefault(t *testing.T) {
	model := models.InitDatabase("testdb")
	modelSchema := models.InitSchema("public")
	plans := models.InitTable("plans", "public")

	tags := models.InitColumn("tags", "plans", "public")
	tags.Type = "text[]"
	tags.NotNull = true
	tags.Default = "''{}''"
	plans.Columns["tags"] = tags

	modelSchema.Tables = append(modelSchema.Tables, plans)
	model.Schemas = append(model.Schemas, modelSchema)

	current := models.InitDatabase("testdb")
	current.Schemas = append(current.Schemas, models.InitSchema("public"))

	writer, err := NewMigrationWriter(&writers.WriterOptions{})
	if err != nil {
		t.Fatalf("Failed to create writer: %v", err)
	}
	var buf bytes.Buffer
	writer.writer = &buf

	if err := writer.WriteMigration(model, current); err != nil {
		t.Fatalf("WriteMigration failed: %v", err)
	}

	output := buf.String()
	if !strings.Contains(output, "tags text[] DEFAULT '{}' NOT NULL") {
		t.Fatalf("expected normalized array default in migration, got:\n%s", output)
	}
	if strings.Contains(output, "'''{}'''") {
		t.Fatalf("migration still contains triple-quoted array default:\n%s", output)
	}
}
// TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass checks that a GIN
// index over a scalar text column is emitted with the gin_trgm_ops opclass.
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
	model := models.InitDatabase("testdb")
	modelSchema := models.InitSchema("public")
	articles := models.InitTable("articles", "public")

	title := models.InitColumn("title", "articles", "public")
	title.Type = "text"
	articles.Columns["title"] = title

	ginIndex := &models.Index{
		Name:    "idx_articles_title_gin",
		Type:    "gin",
		Columns: []string{"title"},
	}
	articles.Indexes[ginIndex.Name] = ginIndex

	modelSchema.Tables = append(modelSchema.Tables, articles)
	model.Schemas = append(model.Schemas, modelSchema)

	current := models.InitDatabase("testdb")
	current.Schemas = append(current.Schemas, models.InitSchema("public"))

	writer, err := NewMigrationWriter(&writers.WriterOptions{})
	if err != nil {
		t.Fatalf("Failed to create writer: %v", err)
	}
	var buf bytes.Buffer
	writer.writer = &buf

	if err := writer.WriteMigration(model, current); err != nil {
		t.Fatalf("WriteMigration failed: %v", err)
	}

	if output := buf.String(); !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
		t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
	}
}
// TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass checks
// that a GIN index over a text[] column is emitted without gin_trgm_ops.
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
	model := models.InitDatabase("testdb")
	modelSchema := models.InitSchema("public")
	plans := models.InitTable("plans", "public")

	tags := models.InitColumn("tags", "plans", "public")
	tags.Type = "text[]"
	plans.Columns["tags"] = tags

	ginIndex := &models.Index{
		Name:    "idx_plans_tags",
		Type:    "gin",
		Columns: []string{"tags"},
	}
	plans.Indexes[ginIndex.Name] = ginIndex

	modelSchema.Tables = append(modelSchema.Tables, plans)
	model.Schemas = append(model.Schemas, modelSchema)

	current := models.InitDatabase("testdb")
	current.Schemas = append(current.Schemas, models.InitSchema("public"))

	writer, err := NewMigrationWriter(&writers.WriterOptions{})
	if err != nil {
		t.Fatalf("Failed to create writer: %v", err)
	}
	var buf bytes.Buffer
	writer.writer = &buf

	if err := writer.WriteMigration(model, current); err != nil {
		t.Fatalf("WriteMigration failed: %v", err)
	}

	output := buf.String()
	if !strings.Contains(output, "USING gin (tags)") {
		t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output)
	}
	if strings.Contains(output, "gin_trgm_ops") {
		t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
	}
}
func TestWriteMigration_WithAudit(t *testing.T) { func TestWriteMigration_WithAudit(t *testing.T) {
// Current database (empty) // Current database (empty)
current := models.InitDatabase("testdb") current := models.InitDatabase("testdb")
@@ -282,3 +405,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
t.Error("Migration missing FOREIGN KEY") t.Error("Migration missing FOREIGN KEY")
} }
} }
// TestNewMigrationWriter_NilOptions verifies that nil options yield a usable
// writer with defaults filled in instead of an error or nil fields.
func TestNewMigrationWriter_NilOptions(t *testing.T) {
	writer, err := NewMigrationWriter(nil)
	switch {
	case err != nil:
		t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
	case writer == nil:
		t.Fatal("expected writer instance")
	case writer.options == nil:
		t.Fatal("expected default writer options to be initialized")
	}
}
// TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty verifies that passing a
// nil current database behaves like an empty one, yielding CREATE TABLE DDL.
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
	model := models.InitDatabase("testdb")
	modelSchema := models.InitSchema("public")
	users := models.InitTable("users", "public")

	id := models.InitColumn("id", "users", "public")
	id.Type = "integer"
	id.NotNull = true
	users.Columns["id"] = id

	modelSchema.Tables = append(modelSchema.Tables, users)
	model.Schemas = append(model.Schemas, modelSchema)

	writer, err := NewMigrationWriter(nil)
	if err != nil {
		t.Fatalf("Failed to create writer: %v", err)
	}
	var buf bytes.Buffer
	writer.writer = &buf

	if err := writer.WriteMigration(model, nil); err != nil {
		t.Fatalf("WriteMigration with nil current failed: %v", err)
	}

	if output := buf.String(); !strings.Contains(output, "CREATE TABLE") {
		t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
	}
}

View File

@@ -8,6 +8,7 @@ import (
"text/template" "text/template"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
//go:embed templates/*.tmpl //go:embed templates/*.tmpl
@@ -266,6 +267,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
ConstraintName string ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'" AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string Columns string
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
} }
// Execute methods for each template // Execute methods for each template
@@ -495,7 +497,11 @@ func BuildCreateTableData(schemaName string, table *models.Table) CreateTableDat
NotNull: col.NotNull, NotNull: col.NotNull,
} }
if col.Default != nil { if col.Default != nil {
colData.Default = fmt.Sprintf("%v", col.Default) if value, ok := col.Default.(string); ok {
colData.Default = writers.QuoteDefaultValue(value, col.Type)
} else {
colData.Default = fmt.Sprintf("%v", col.Default)
}
} }
columns = append(columns, colData) columns = append(columns, colData)
} }

View File

@@ -1,26 +1,42 @@
DO $$ DO $$
DECLARE DECLARE
auto_pk_name text; current_pk_name text;
current_pk_matches boolean := false;
BEGIN BEGIN
-- Drop auto-generated primary key if it exists SELECT tc.constraint_name,
SELECT constraint_name INTO auto_pk_name COALESCE(
FROM information_schema.table_constraints ARRAY(
WHERE table_schema = '{{.SchemaName}}' SELECT a.attname
AND table_name = '{{.TableName}}' FROM pg_constraint c
AND constraint_type = 'PRIMARY KEY' JOIN pg_class t ON t.oid = c.conrelid
AND constraint_name IN ({{.AutoGenNames}}); JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
ON TRUE
JOIN pg_attribute a
ON a.attrelid = t.oid
AND a.attnum = cols.attnum
WHERE c.contype = 'p'
AND n.nspname = '{{.SchemaName}}'
AND t.relname = '{{.TableName}}'
ORDER BY cols.ord
),
ARRAY[]::text[]
) = ARRAY[{{.ColumnNames}}]
INTO current_pk_name, current_pk_matches
FROM information_schema.table_constraints tc
WHERE tc.table_schema = '{{.SchemaName}}'
AND tc.table_name = '{{.TableName}}'
AND tc.constraint_type = 'PRIMARY KEY';
IF auto_pk_name IS NOT NULL THEN IF current_pk_name IS NOT NULL
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name); AND NOT current_pk_matches
AND current_pk_name IN ({{.AutoGenNames}}) THEN
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
END IF; END IF;
-- Add named primary key if it doesn't exist -- Add the desired primary key only when no matching primary key already exists.
IF NOT EXISTS ( IF current_pk_name IS NULL
SELECT 1 FROM information_schema.table_constraints OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}}); ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF; END IF;
END; END;

View File

@@ -228,6 +228,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
ConstraintName: pkName, ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames), AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "), Columns: strings.Join(pkColumns, ", "),
ColumnNames: formatStringList(pkColumns),
} }
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -523,15 +524,7 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
if col.Default != nil { if col.Default != nil {
switch v := col.Default.(type) { switch v := col.Default.(type) {
case string: case string:
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't parts = append(parts, fmt.Sprintf("DEFAULT %s", writers.QuoteDefaultValue(stripBackticks(v), col.Type)))
cleanDefault := stripBackticks(v)
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
} else if cleanDefault == "true" || cleanDefault == "false" {
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
} else {
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
}
case bool: case bool:
parts = append(parts, fmt.Sprintf("DEFAULT %v", v)) parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
default: default:
@@ -814,6 +807,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
ConstraintName: pkName, ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames), AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "), Columns: strings.Join(columnNames, ", "),
ColumnNames: formatStringList(columnNames),
} }
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data) sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -1259,6 +1253,9 @@ func isIntegerType(colType string) bool {
func isTextType(colType string) bool { func isTextType(colType string) bool {
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"} textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
lowerType := strings.ToLower(colType) lowerType := strings.ToLower(colType)
if strings.HasSuffix(lowerType, "[]") {
return false
}
for _, t := range textTypes { for _, t := range textTypes {
if strings.HasPrefix(lowerType, t) { if strings.HasPrefix(lowerType, t) {
return true return true

View File

@@ -87,6 +87,43 @@ func TestWriteDatabase(t *testing.T) {
} }
} }
// TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass ensures
// the schema writer emits GIN indexes on text[] columns without gin_trgm_ops,
// which only applies to scalar text columns.
func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	plans := models.InitTable("plans", "public")

	tags := models.InitColumn("tags", "plans", "public")
	tags.Type = "text[]"
	plans.Columns["tags"] = tags

	ginIndex := &models.Index{
		Name:    "idx_plans_tags",
		Type:    "gin",
		Columns: []string{"tags"},
	}
	plans.Indexes[ginIndex.Name] = ginIndex

	schema.Tables = append(schema.Tables, plans)
	db.Schemas = append(db.Schemas, schema)

	writer := NewWriter(&writers.WriterOptions{})
	var buf bytes.Buffer
	writer.writer = &buf

	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	output := buf.String()
	if !strings.Contains(output, `USING gin (tags)`) {
		t.Fatalf("expected GIN index on array column without explicit trigram opclass, got:\n%s", output)
	}
	if strings.Contains(output, "gin_trgm_ops") {
		t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
	}
}
func TestWriteForeignKeys(t *testing.T) { func TestWriteForeignKeys(t *testing.T) {
// Create a test database with two related tables // Create a test database with two related tables
db := models.InitDatabase("testdb") db := models.InitDatabase("testdb")
@@ -636,9 +673,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output) t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
} }
// Verify it checks for our specific named constraint before adding it // Verify it compares the current primary key columns before dropping/recreating
if !strings.Contains(output, "constraint_name = 'pk_public_products'") { if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output) t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
}
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
} }
} }
@@ -729,6 +771,43 @@ func TestColumnSizeSpecifiers(t *testing.T) {
} }
} }
// TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey checks
// that the generated primary-key DO block compares the current PK columns and
// avoids dropping an auto-generated key whose columns already match.
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	learnings := models.InitTable("learnings", "public")

	id := models.InitColumn("id", "learnings", "public")
	id.Type = "bigint"
	id.IsPrimaryKey = true
	learnings.Columns["id"] = id

	parent := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
	parent.Type = "bigint"
	learnings.Columns["duplicate_of_learning_id"] = parent

	schema.Tables = append(schema.Tables, learnings)
	db.Schemas = append(db.Schemas, schema)

	writer := NewWriter(&writers.WriterOptions{})
	var buf bytes.Buffer
	writer.writer = &buf

	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	output := buf.String()
	if !strings.Contains(output, "current_pk_matches") {
		t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
	}
	if !strings.Contains(output, "ARRAY['id']") {
		t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
	}
	if !strings.Contains(output, "NOT current_pk_matches") {
		t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
	}
}
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) { func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
writer := NewWriter(&writers.WriterOptions{}) writer := NewWriter(&writers.WriterOptions{})

View File

@@ -61,7 +61,7 @@ func (w *Writer) databaseToPrisma(db *models.Database) string {
sb.WriteString("\n") sb.WriteString("\n")
// Write generator block // Write generator block
sb.WriteString(w.generateGenerator()) sb.WriteString(w.generateGenerator(db))
sb.WriteString("\n") sb.WriteString("\n")
// Process all schemas (typically just one in Prisma) // Process all schemas (typically just one in Prisma)
@@ -114,13 +114,28 @@ func (w *Writer) generateDatasource(db *models.Database) string {
} }
// generateGenerator generates the generator block // generateGenerator generates the generator block
func (w *Writer) generateGenerator() string { func (w *Writer) generateGenerator(db *models.Database) string {
if w.usePrisma7Generator(db) {
return `generator client {
provider = "prisma-client"
output = "./generated"
}
`
}
return `generator client { return `generator client {
provider = "prisma-client-js" provider = "prisma-client-js"
} }
` `
} }
func (w *Writer) usePrisma7Generator(db *models.Database) bool {
if w.options != nil && w.options.Prisma7 {
return true
}
return db != nil && db.SourceFormat == "prisma7"
}
// enumToPrisma converts an Enum to Prisma enum block // enumToPrisma converts an Enum to Prisma enum block
func (w *Writer) enumToPrisma(enum *models.Enum) string { func (w *Writer) enumToPrisma(enum *models.Enum) string {
var sb strings.Builder var sb strings.Builder

View File

@@ -0,0 +1,52 @@
package prisma
import (
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// TestGenerateGenerator_DefaultsToPrismaClientJS verifies the generator block
// defaults to the prisma-client-js provider with no Prisma 7 output path.
func TestGenerateGenerator_DefaultsToPrismaClientJS(t *testing.T) {
	t.Parallel()

	w := NewWriter(&writers.WriterOptions{})
	block := w.generateGenerator(models.InitDatabase("testdb"))

	if !strings.Contains(block, `provider = "prisma-client-js"`) {
		t.Fatalf("expected prisma-client-js generator, got:\n%s", block)
	}
	if strings.Contains(block, `output = "./generated"`) {
		t.Fatalf("did not expect prisma7 output path in default generator:\n%s", block)
	}
}
// TestGenerateGenerator_Prisma7FlagUsesPrismaClient verifies that enabling the
// Prisma7 writer option switches the generator to prisma-client with an
// explicit output path.
func TestGenerateGenerator_Prisma7FlagUsesPrismaClient(t *testing.T) {
	t.Parallel()

	w := NewWriter(&writers.WriterOptions{Prisma7: true})
	block := w.generateGenerator(models.InitDatabase("testdb"))

	if !strings.Contains(block, `provider = "prisma-client"`) {
		t.Fatalf("expected prisma-client generator, got:\n%s", block)
	}
	if !strings.Contains(block, `output = "./generated"`) {
		t.Fatalf("expected prisma7 output path, got:\n%s", block)
	}
}
// TestGenerateGenerator_Prisma7SourceFormatUsesPrismaClient verifies that a
// database whose SourceFormat is "prisma7" selects the prisma-client generator
// even without the writer option set.
func TestGenerateGenerator_Prisma7SourceFormatUsesPrismaClient(t *testing.T) {
	t.Parallel()

	w := NewWriter(&writers.WriterOptions{})
	db := models.InitDatabase("testdb")
	db.SourceFormat = "prisma7"

	if block := w.generateGenerator(db); !strings.Contains(block, `provider = "prisma-client"`) {
		t.Fatalf("expected prisma-client generator from source format, got:\n%s", block)
	}
}

View File

@@ -51,6 +51,9 @@ type WriterOptions struct {
// "stdlib" — database/sql (sql.NullString, sql.NullInt32, …) // "stdlib" — database/sql (sql.NullString, sql.NullInt32, …)
NullableTypes string NullableTypes string
// Prisma7 enables Prisma 7-specific output for Prisma writers.
Prisma7 bool
// Additional options can be added here as needed // Additional options can be added here as needed
Metadata map[string]interface{} Metadata map[string]interface{}
} }
@@ -110,8 +113,12 @@ func SanitizeFilename(name string) string {
// Examples (bigint): "0" → "0" // Examples (bigint): "0" → "0"
// Examples (timestamp): "now()" → "now()" (function call never quoted) // Examples (timestamp): "now()" → "now()" (function call never quoted)
func QuoteDefaultValue(value, sqlType string) string { func QuoteDefaultValue(value, sqlType string) string {
value = strings.TrimSpace(value)
// Function calls are never quoted regardless of column type. // Function calls are never quoted regardless of column type.
if strings.Contains(value, "(") || strings.Contains(value, ")") { if strings.Contains(value, "(") || strings.Contains(value, ")") ||
strings.Contains(value, "::") ||
strings.HasPrefix(strings.ToUpper(value), "ARRAY[") {
return value return value
} }
@@ -121,6 +128,16 @@ func QuoteDefaultValue(value, sqlType string) string {
baseType = baseType[:idx] baseType = baseType[:idx]
} }
if isArraySQLType(baseType) {
if arrayLiteral, ok := normalizeArrayDefaultLiteral(value); ok {
return quoteSQLLiteral(arrayLiteral)
}
}
if isQuotedSQLLiteral(value) {
return value
}
// Types whose default values must NOT be quoted. // Types whose default values must NOT be quoted.
unquotedTypes := map[string]bool{ unquotedTypes := map[string]bool{
// Integer types // Integer types
@@ -154,7 +171,32 @@ func QuoteDefaultValue(value, sqlType string) string {
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …) // Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
// is treated as a quoted literal. // is treated as a quoted literal.
return "'" + value + "'" return quoteSQLLiteral(value)
}
// isArraySQLType reports whether sqlType denotes a PostgreSQL array type,
// i.e. it carries the "[]" suffix (e.g. "text[]").
func isArraySQLType(sqlType string) bool {
	const arraySuffix = "[]"
	return strings.HasSuffix(sqlType, arraySuffix)
}
// normalizeArrayDefaultLiteral strips surrounding quoting layers from an array
// default value and returns the bare {...} literal.
//
// It accepts a doubled-quoted form (''{...}''), a single-quoted form
// ('{...}'), or an unquoted form ({...}); the second result is false when the
// value does not match any of these shapes.
func normalizeArrayDefaultLiteral(value string) (string, bool) {
	if strings.HasPrefix(value, "''{") && strings.HasSuffix(value, "}''") {
		return value[2 : len(value)-2], true
	}
	if strings.HasPrefix(value, "'{") && strings.HasSuffix(value, "}'") {
		return value[1 : len(value)-1], true
	}
	if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
		return value, true
	}
	return "", false
}
// isQuotedSQLLiteral reports whether value is already wrapped in single
// quotes, requiring at least two characters so a lone quote does not match.
func isQuotedSQLLiteral(value string) bool {
	if len(value) < 2 {
		return false
	}
	return strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")
}
func quoteSQLLiteral(value string) string {
return "'" + strings.ReplaceAll(value, "'", "''") + "'"
} }
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags. // SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
@@ -165,7 +207,8 @@ func QuoteDefaultValue(value, sqlType string) string {
// - Returns a clean identifier safe for use in struct tags and field names // - Returns a clean identifier safe for use in struct tags and field names
func SanitizeStructTagValue(value string) string { func SanitizeStructTagValue(value string) string {
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description']) // Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`) // Require at least one character inside brackets to avoid stripping PostgreSQL array suffix "[]"
commentRegex := regexp.MustCompile(`\s*\[[^\]]+\]\s*`)
value = commentRegex.ReplaceAllString(value, "") value = commentRegex.ReplaceAllString(value, "")
// Trim whitespace // Trim whitespace

View File

@@ -0,0 +1,54 @@
package writers
import "testing"
// TestQuoteDefaultValue exercises default-value quoting for plain text,
// array literals in various quoting states, and function-call defaults.
func TestQuoteDefaultValue(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		value   string
		sqlType string
		want    string
	}{
		{name: "text default is quoted", value: "active", sqlType: "text", want: "'active'"},
		{name: "array default from bare literal is quoted once", value: "{}", sqlType: "text[]", want: "'{}'"},
		{name: "array default from quoted literal is preserved", value: "'{}'", sqlType: "text[]", want: "'{}'"},
		{name: "array default from double quoted literal is normalized", value: "''{}''", sqlType: "text[]", want: "'{}'"},
		{name: "function default is left alone", value: "now()", sqlType: "timestamptz", want: "now()"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			if got := QuoteDefaultValue(tc.value, tc.sqlType); got != tc.want {
				t.Fatalf("QuoteDefaultValue(%q, %q) = %q, want %q", tc.value, tc.sqlType, got, tc.want)
			}
		})
	}
}