Compare commits

26 Commits

Author SHA1 Message Date
72200ea72e chore(release): update package version to 1.0.55
All checks were successful
Release / test (push) Successful in -32m1s
Release / release (push) Successful in -31m13s
Release / pkg-aur (push) Successful in -32m13s
Release / pkg-deb (push) Successful in -31m12s
Release / pkg-rpm (push) Successful in -29m45s
2026-05-05 11:36:29 +02:00
608893a3d6 feat(index): implement GIN index support for quoted text columns and enhance index column resolution 2026-05-05 11:32:15 +02:00
53ff745d5d chore(release): update package version to 1.0.54
All checks were successful
Release / test (push) Successful in -31m47s
Release / release (push) Successful in -31m9s
Release / pkg-aur (push) Successful in -31m57s
Release / pkg-deb (push) Successful in -31m1s
Release / pkg-rpm (push) Successful in -29m27s
2026-05-05 11:12:49 +02:00
17bc8ed395 feat(migration): enhance primary key handling and add GIN index support in migration writer 2026-05-05 11:12:23 +02:00
a447b68b22 chore(release): update package version to 1.0.53
All checks were successful
Release / test (push) Successful in -31m55s
Release / release (push) Successful in -31m19s
Release / pkg-aur (push) Successful in -32m3s
Release / pkg-deb (push) Successful in -31m21s
Release / pkg-rpm (push) Successful in -28m4s
2026-05-05 10:48:27 +02:00
4303dcf59b Support typed primary key helpers in gorm and bun writers 2026-05-05 10:32:33 +02:00
e828d48798 chore(release): update package version to 1.0.52
All checks were successful
Release / test (push) Successful in -32m39s
Release / release (push) Successful in -32m1s
Release / pkg-deb (push) Successful in -32m9s
Release / pkg-aur (push) Successful in -31m37s
Release / pkg-rpm (push) Successful in -27m28s
2026-05-03 17:19:22 +02:00
6e470a9239 fix(type_mapper): adjust array tag handling in BuildBunTag 2026-05-03 17:18:58 +02:00
096815fe49 chore(release): update package version to 1.0.51
All checks were successful
Release / test (push) Successful in -32m30s
Release / release (push) Successful in -31m54s
Release / pkg-aur (push) Successful in -32m31s
Release / pkg-deb (push) Successful in -32m7s
Release / pkg-rpm (push) Successful in -30m36s
2026-05-03 16:11:13 +02:00
b8f60203cb fix(type_mapper): handle PostgreSQL array types in tags
* Update BuildBunTag to append "array" for array types
* Add tests for handling array types in TypeMapper
* Adjust regex in SanitizeStructTagValue to preserve array suffix
2026-05-03 16:11:01 +02:00
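For context, a minimal sketch of how the `array` tag option looks on a Bun model field; the struct and column names are invented for illustration, and the exact tag layout emitted by BuildBunTag is an assumption.

```go
package models

import "github.com/uptrace/bun"

// Post is a hypothetical model: the ",array" option tells Bun to map the Go
// slice to a PostgreSQL array column (tags text[]) rather than JSON.
type Post struct {
	bun.BaseModel `bun:"table:posts"`

	ID   int64    `bun:"id,pk,autoincrement"`
	Tags []string `bun:"tags,array"`
}
```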
Hein
15763f60cc Fix GIN opclass handling for array columns 2026-04-30 20:35:06 +02:00
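As a rough illustration of the kind of column involved (not code from this repository): a GIN index over a PostgreSQL array column needs no explicit operator class, so a GORM declaration along these lines should come out as plain `USING gin (tags)`. The model, index name, and pq dependency are assumptions for the example.

```go
package models

import "github.com/lib/pq"

// Post is a hypothetical GORM model: pq.StringArray maps to text[], and the
// index tag asks the migrator for a GIN index, which PostgreSQL supports
// natively on array columns via the default array operator class.
type Post struct {
	ID   uint           `gorm:"primaryKey"`
	Tags pq.StringArray `gorm:"type:text[];index:idx_posts_tags_gin,type:gin"`
}
```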
Hein
6d2884f5cf chore(release): update package version to 1.0.50
Some checks failed
Release / test (push) Successful in -32m41s
Release / release (push) Successful in -28m56s
Release / pkg-deb (push) Successful in -31m19s
Release / pkg-aur (push) Successful in -27m21s
Release / pkg-rpm (push) Failing after -26m24s
2026-04-30 20:23:29 +02:00
Hein
f192decff8 Add Prisma 7 flag support 2026-04-30 20:22:57 +02:00
Hein
8b906cf4a3 chore(release): update package version to 1.0.49
All checks were successful
Release / test (push) Successful in -32m39s
Release / release (push) Successful in -31m40s
Release / pkg-aur (push) Successful in -32m46s
Release / pkg-deb (push) Successful in -32m9s
Release / pkg-rpm (push) Successful in -29m53s
2026-04-30 18:16:28 +02:00
Hein
0a3966e6fc fix(pgsql): handle default values for array types in migrations
* update default value quoting logic for PostgreSQL
* add tests for array default value handling
2026-04-30 18:16:21 +02:00
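A rough sketch of the quoting rule described above; quoteArrayDefault is a made-up helper (not the project's function) and only illustrates the single-quoting behaviour for array defaults.

```go
package main

import (
	"fmt"
	"strings"
)

// quoteArrayDefault illustrates the intended rule: bare array literals such as
// {} or {1,2} are wrapped in single quotes exactly once, while values that are
// already SQL string literals pass through unchanged.
func quoteArrayDefault(def string) string {
	def = strings.TrimSpace(def)
	switch {
	case strings.HasPrefix(def, "'"): // already quoted, e.g. '{}'
		return def
	case strings.HasPrefix(def, "{"): // bare array literal, e.g. {}
		return "'" + def + "'"
	default:
		return def
	}
}

func main() {
	fmt.Println(quoteArrayDefault("{}"))   // '{}'
	fmt.Println(quoteArrayDefault("'{}'")) // '{}' (left alone, not quoted twice)
}
```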
Hein
d30fc24f55 chore(release): update package version to 1.0.48
All checks were successful
Release / pkg-deb (push) Successful in -32m6s
Release / test (push) Successful in -32m44s
Release / release (push) Successful in -32m5s
Release / pkg-aur (push) Successful in -32m38s
Release / pkg-rpm (push) Successful in -30m46s
2026-04-30 16:07:33 +02:00
Hein
16a489d0b8 style(pkg): align json and numeric type mappings 2026-04-30 16:07:16 +02:00
Hein
3524e86282 feat: add --types flag and stdlib nullable type support for bun/gorm writers
* Fix pgsql reader double-quoting defaults: normalizePostgresDefault strips
  surrounding SQL string literal quotes from column_default before storing,
  matching the convention used by every other reader.

* Add NullableTypes field to WriterOptions with NullableTypeResolveSpec
  (default) and NullableTypeStdlib constants.

* Both bun and gorm TypeMappers now accept a typeStyle parameter. stdlib
  mode produces sql.NullString/NullInt32/NullTime etc. for nullable scalars,
  plain Go slices for arrays, and time.Time for NOT NULL timestamps. Default
  resolvespec behaviour is unchanged.

* Add --types flag to convert and split commands.

* Update bun/README.md and gorm/README.md with side-by-side generated code
  examples, updated type mapping tables, and Writer Options documentation.
2026-04-30 16:00:54 +02:00
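To make the two modes concrete, a hedged sketch of what `--types stdlib` output for a Bun model might look like, with a nullable name column and a NOT NULL created_at timestamp; the table and field names are invented, and the resolvespec-mode type names are deliberately not spelled out here.

```go
package models

import (
	"database/sql"
	"time"

	"github.com/uptrace/bun"
)

// User sketches the stdlib mode described above: nullable scalars use the
// database/sql null wrappers and NOT NULL timestamps use time.Time. Under the
// default resolvespec mode the same columns would use that package's nullable
// types instead.
type User struct {
	bun.BaseModel `bun:"table:users"`

	ID        int64          `bun:"id,pk,autoincrement"`
	Name      sql.NullString `bun:"name"`
	CreatedAt time.Time      `bun:"created_at,notnull"`
}
```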
Hein
1e54fdcd7f Merge branch 'master' of git.warky.dev:wdevs/relspecgo 2026-04-30 15:15:34 +02:00
fb104ea084 feat: PostgreSQL connections opened by relspec set application_name by default to relspecgo/<version>
All checks were successful
Release / test (push) Successful in -31m41s
Release / release (push) Successful in -28m47s
Release / pkg-aur (push) Successful in -32m40s
Release / pkg-deb (push) Successful in -32m25s
Release / pkg-rpm (push) Successful in -28m30s
2026-04-26 17:48:26 +02:00
837160b77a feat(pgsql): implement application_name handling in connection 2026-04-26 17:45:25 +02:00
ed7130bba8 refactor(pkg): canonicalize base types and adjust length handling
* Update base types to keep explicit modifier forms
* Modify length handling for vector types in tests
2026-04-26 17:35:15 +02:00
4ca1810d07 refactor(dctx): sort table columns and indexes for deterministic output
Some checks failed
Release / test (push) Failing after -31m18s
Release / release (push) Has been skipped
Release / pkg-aur (push) Has been skipped
Release / pkg-deb (push) Has been skipped
Release / pkg-rpm (push) Has been skipped
2026-04-26 12:50:39 +02:00
c0880cb076 feat(pkg): preserve PostgreSQL types in mapDataType function
Some checks failed
Release / test (push) Failing after -31m27s
Release / release (push) Has been skipped
Release / pkg-aur (push) Has been skipped
Release / pkg-deb (push) Has been skipped
Release / pkg-rpm (push) Has been skipped
* Add support for known PostgreSQL types and modifiers
* Implement canonicalization for PostgreSQL types
* Introduce unit tests for PostgreSQL type handling
2026-04-26 12:43:44 +02:00
988798998d test(drawdb): add test for converting column types with modifiers
* Implement tests to ensure explicit type modifiers are preserved during conversion.
* Validate behavior for varchar, numeric, and custom vector types.
2026-04-26 12:35:54 +02:00
Hein
3d9cc7ec58 .
All checks were successful
Release / Build and Release (push) Successful in -25m33s
2026-02-20 16:32:19 +02:00
60 changed files with 3126 additions and 579 deletions

.codex (new file)

@@ -42,6 +42,11 @@ relspec convert --from pgsql --from-conn "postgres://..." --to sqlite --to-path
relspec convert --from json --from-list "a.json,b.json" --to yaml --to-path merged.yaml
```
PostgreSQL connections opened by relspec set `application_name` by default to
`relspecgo/<version>` (with component suffixes internally, e.g. readers/writers).
If you need a custom value, provide `application_name` explicitly in the connection
string query parameters.
### `merge` — Additive schema merge (never modifies existing items)
```bash

@@ -52,6 +52,7 @@ var (
convertPackageName string convertPackageName string
convertSchemaFilter string convertSchemaFilter string
convertFlattenSchema bool convertFlattenSchema bool
convertNullableTypes string
) )
var convertCmd = &cobra.Command{ var convertCmd = &cobra.Command{
@@ -175,6 +176,7 @@ func init() {
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)") convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)") convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)") convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
convertCmd.Flags().StringVar(&convertNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
err := convertCmd.MarkFlagRequired("from") err := convertCmd.MarkFlagRequired("from")
if err != nil { if err != nil {
@@ -241,7 +243,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter) fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
} }
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema); err != nil { if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes); err != nil {
return fmt.Errorf("failed to write target: %w", err) return fmt.Errorf("failed to write target: %w", err)
} }
@@ -284,79 +286,79 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DBML format") return nil, fmt.Errorf("file path is required for DBML format")
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DCTX format") return nil, fmt.Errorf("file path is required for DCTX format")
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DrawDB format") return nil, fmt.Errorf("file path is required for DrawDB format")
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for JSON format") return nil, fmt.Errorf("file path is required for JSON format")
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml", "yml": case "yaml", "yml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for YAML format") return nil, fmt.Errorf("file path is required for YAML format")
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "pgsql", "postgres", "postgresql": case "pgsql", "postgres", "postgresql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("connection string is required for PostgreSQL format") return nil, fmt.Errorf("connection string is required for PostgreSQL format")
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GORM format") return nil, fmt.Errorf("file path is required for GORM format")
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Bun format") return nil, fmt.Errorf("file path is required for Bun format")
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Drizzle format") return nil, fmt.Errorf("file path is required for Drizzle format")
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Prisma format") return nil, fmt.Errorf("file path is required for Prisma format")
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for TypeORM format") return nil, fmt.Errorf("file path is required for TypeORM format")
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "graphql", "gql": case "graphql", "gql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GraphQL format") return nil, fmt.Errorf("file path is required for GraphQL format")
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022": case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("connection string is required for MSSQL format") return nil, fmt.Errorf("connection string is required for MSSQL format")
} }
reader = mssql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = mssql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
@@ -367,7 +369,7 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("file path or connection string is required for SQLite format") return nil, fmt.Errorf("file path or connection string is required for SQLite format")
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("unsupported source format: %s", dbType) return nil, fmt.Errorf("unsupported source format: %s", dbType)
@@ -381,14 +383,10 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
return db, nil return db, nil
} }
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool) error { func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error {
var writer writers.Writer var writer writers.Writer
writerOpts := &writers.WriterOptions{ writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes)
OutputPath: outputPath,
PackageName: packageName,
FlattenSchema: flattenSchema,
}
switch strings.ToLower(dbType) { switch strings.ToLower(dbType) {
case "dbml": case "dbml":

@@ -240,62 +240,62 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DBML format", label) return nil, fmt.Errorf("%s: file path is required for DBML format", label)
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DCTX format", label) return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label) return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label) return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for JSON format", label) return nil, fmt.Errorf("%s: file path is required for JSON format", label)
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml": case "yaml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for YAML format", label) return nil, fmt.Errorf("%s: file path is required for YAML format", label)
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GORM format", label) return nil, fmt.Errorf("%s: file path is required for GORM format", label)
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Bun format", label) return nil, fmt.Errorf("%s: file path is required for Bun format", label)
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label) return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Prisma format", label) return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label) return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "pgsql": case "pgsql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label) return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
dbPath := filePath dbPath := filePath
@@ -305,7 +305,7 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label) return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType) return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType)
} }
@@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa
switch strings.ToLower(dbType) { switch strings.ToLower(dbType) {
case "dbml": case "dbml":
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, ""))
case "dctx": case "dctx":
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, ""))
case "drawdb": case "drawdb":
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, ""))
case "graphql": case "graphql":
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, ""))
case "json": case "json":
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wjson.NewWriter(newWriterOptions(filePath, "", false, ""))
case "yaml": case "yaml":
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, ""))
case "gorm": case "gorm":
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, ""))
case "bun": case "bun":
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wbun.NewWriter(newWriterOptions(filePath, "", false, ""))
case "drizzle": case "drizzle":
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, ""))
case "prisma": case "prisma":
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, ""))
case "typeorm": case "typeorm":
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, ""))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, ""))
case "pgsql": case "pgsql":
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath}) writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, ""))
default: default:
return fmt.Errorf("%s: unsupported format: %s", label, dbType) return fmt.Errorf("%s: unsupported format: %s", label, dbType)
} }

@@ -221,73 +221,73 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DBML format") return nil, fmt.Errorf("file path is required for DBML format")
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DCTX format") return nil, fmt.Errorf("file path is required for DCTX format")
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for DrawDB format") return nil, fmt.Errorf("file path is required for DrawDB format")
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GraphQL format") return nil, fmt.Errorf("file path is required for GraphQL format")
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for JSON format") return nil, fmt.Errorf("file path is required for JSON format")
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml", "yml": case "yaml", "yml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for YAML format") return nil, fmt.Errorf("file path is required for YAML format")
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for GORM format") return nil, fmt.Errorf("file path is required for GORM format")
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Bun format") return nil, fmt.Errorf("file path is required for Bun format")
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Drizzle format") return nil, fmt.Errorf("file path is required for Drizzle format")
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for Prisma format") return nil, fmt.Errorf("file path is required for Prisma format")
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("file path is required for TypeORM format") return nil, fmt.Errorf("file path is required for TypeORM format")
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "pgsql", "postgres", "postgresql": case "pgsql", "postgres", "postgresql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("connection string is required for PostgreSQL format") return nil, fmt.Errorf("connection string is required for PostgreSQL format")
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
@@ -298,7 +298,7 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("file path or connection string is required for SQLite format") return nil, fmt.Errorf("file path or connection string is required for SQLite format")
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("unsupported database type: %s", dbType) return nil, fmt.Errorf("unsupported database type: %s", dbType)

@@ -284,62 +284,62 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DBML format", label) return nil, fmt.Errorf("%s: file path is required for DBML format", label)
} }
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dbml.NewReader(newReaderOptions(filePath, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DCTX format", label) return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
} }
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = dctx.NewReader(newReaderOptions(filePath, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label) return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
} }
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drawdb.NewReader(newReaderOptions(filePath, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label) return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
} }
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = graphql.NewReader(newReaderOptions(filePath, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for JSON format", label) return nil, fmt.Errorf("%s: file path is required for JSON format", label)
} }
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = json.NewReader(newReaderOptions(filePath, ""))
case "yaml": case "yaml":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for YAML format", label) return nil, fmt.Errorf("%s: file path is required for YAML format", label)
} }
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = yaml.NewReader(newReaderOptions(filePath, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GORM format", label) return nil, fmt.Errorf("%s: file path is required for GORM format", label)
} }
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = gorm.NewReader(newReaderOptions(filePath, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Bun format", label) return nil, fmt.Errorf("%s: file path is required for Bun format", label)
} }
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = bun.NewReader(newReaderOptions(filePath, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label) return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
} }
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = drizzle.NewReader(newReaderOptions(filePath, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Prisma format", label) return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
} }
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = prisma.NewReader(newReaderOptions(filePath, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label) return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
} }
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) reader = typeorm.NewReader(newReaderOptions(filePath, ""))
case "pgsql": case "pgsql":
if connString == "" { if connString == "" {
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label) return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
} }
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) reader = pgsql.NewReader(newReaderOptions("", connString))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
// SQLite can use either file path or connection string // SQLite can use either file path or connection string
dbPath := filePath dbPath := filePath
@@ -349,7 +349,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
if dbPath == "" { if dbPath == "" {
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label) return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
} }
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath}) reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
default: default:
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType) return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
} }
@@ -370,61 +370,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for DBML format", label) return fmt.Errorf("%s: file path is required for DBML format", label)
} }
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "dctx": case "dctx":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for DCTX format", label) return fmt.Errorf("%s: file path is required for DCTX format", label)
} }
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "drawdb": case "drawdb":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for DrawDB format", label) return fmt.Errorf("%s: file path is required for DrawDB format", label)
} }
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "graphql": case "graphql":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for GraphQL format", label) return fmt.Errorf("%s: file path is required for GraphQL format", label)
} }
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "json": case "json":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for JSON format", label) return fmt.Errorf("%s: file path is required for JSON format", label)
} }
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "yaml": case "yaml":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for YAML format", label) return fmt.Errorf("%s: file path is required for YAML format", label)
} }
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "gorm": case "gorm":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for GORM format", label) return fmt.Errorf("%s: file path is required for GORM format", label)
} }
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "bun": case "bun":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for Bun format", label) return fmt.Errorf("%s: file path is required for Bun format", label)
} }
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "drizzle": case "drizzle":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for Drizzle format", label) return fmt.Errorf("%s: file path is required for Drizzle format", label)
} }
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "prisma": case "prisma":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for Prisma format", label) return fmt.Errorf("%s: file path is required for Prisma format", label)
} }
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "typeorm": case "typeorm":
if filePath == "" { if filePath == "" {
return fmt.Errorf("%s: file path is required for TypeORM format", label) return fmt.Errorf("%s: file path is required for TypeORM format", label)
} }
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "sqlite", "sqlite3": case "sqlite", "sqlite3":
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}) writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
case "pgsql": case "pgsql":
writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema} writerOpts := newWriterOptions(filePath, "", flattenSchema, "")
if connString != "" { if connString != "" {
writerOpts.Metadata = map[string]interface{}{ writerOpts.Metadata = map[string]interface{}{
"connection_string": connString, "connection_string": connString,

@@ -0,0 +1,24 @@
package main
import (
"git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
func newReaderOptions(filePath, connString string) *readers.ReaderOptions {
return &readers.ReaderOptions{
FilePath: filePath,
ConnectionString: connString,
Prisma7: prisma7,
}
}
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions {
return &writers.WriterOptions{
OutputPath: outputPath,
PackageName: packageName,
FlattenSchema: flattenSchema,
NullableTypes: nullableTypes,
Prisma7: prisma7,
}
}

@@ -12,6 +12,7 @@ var (
// Version information, set via ldflags during build // Version information, set via ldflags during build
version = "dev" version = "dev"
buildDate = "unknown" buildDate = "unknown"
prisma7 bool
) )
func init() { func init() {
@@ -68,4 +69,5 @@ func init() {
rootCmd.AddCommand(mergeCmd) rootCmd.AddCommand(mergeCmd)
rootCmd.AddCommand(splitCmd) rootCmd.AddCommand(splitCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.PersistentFlags().BoolVar(&prisma7, "prisma7", false, "Use Prisma 7 generator conventions when reading/writing Prisma schemas")
} }

@@ -22,6 +22,7 @@ var (
splitDatabaseName string splitDatabaseName string
splitExcludeSchema string splitExcludeSchema string
splitExcludeTables string splitExcludeTables string
splitNullableTypes string
) )
var splitCmd = &cobra.Command{ var splitCmd = &cobra.Command{
@@ -110,6 +111,7 @@ func init() {
splitCmd.Flags().StringVar(&splitTables, "tables", "", "Comma-separated list of table names to include (case-insensitive)") splitCmd.Flags().StringVar(&splitTables, "tables", "", "Comma-separated list of table names to include (case-insensitive)")
splitCmd.Flags().StringVar(&splitExcludeSchema, "exclude-schema", "", "Comma-separated list of schema names to exclude") splitCmd.Flags().StringVar(&splitExcludeSchema, "exclude-schema", "", "Comma-separated list of schema names to exclude")
splitCmd.Flags().StringVar(&splitExcludeTables, "exclude-tables", "", "Comma-separated list of table names to exclude (case-insensitive)") splitCmd.Flags().StringVar(&splitExcludeTables, "exclude-tables", "", "Comma-separated list of table names to exclude (case-insensitive)")
splitCmd.Flags().StringVar(&splitNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
err := splitCmd.MarkFlagRequired("from") err := splitCmd.MarkFlagRequired("from")
if err != nil { if err != nil {
@@ -185,6 +187,7 @@ func runSplit(cmd *cobra.Command, args []string) error {
splitPackageName, splitPackageName,
"", // no schema filter for split "", // no schema filter for split
false, // no flatten-schema for split false, // no flatten-schema for split
splitNullableTypes,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to write output: %w", err) return fmt.Errorf("failed to write output: %w", err)

@@ -1,6 +1,6 @@
# Maintainer: Hein (Warky Devs) <hein@warky.dev> # Maintainer: Hein (Warky Devs) <hein@warky.dev>
pkgname=relspec pkgname=relspec
pkgver=1.0.44 pkgver=1.0.55
pkgrel=1 pkgrel=1
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs." pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
arch=('x86_64' 'aarch64') arch=('x86_64' 'aarch64')

@@ -1,5 +1,5 @@
Name: relspec Name: relspec
Version: 1.0.44 Version: 1.0.55
Release: 1%{?dist} Release: 1%{?dist}
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs. Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.

pkg/pgsql/connection.go (new file, 85 lines)

@@ -0,0 +1,85 @@
package pgsql
import (
"context"
"fmt"
"runtime/debug"
"strings"
"github.com/jackc/pgx/v5"
)
const (
defaultApplicationPrefix = "relspecgo"
postgresIdentifierMaxLen = 63
)
// BuildApplicationName returns a PostgreSQL application_name in the form:
// relspecgo/<version>[:<component>]
func BuildApplicationName(component string) string {
appName := fmt.Sprintf("%s/%s", defaultApplicationPrefix, relspecVersion())
component = strings.TrimSpace(component)
if component != "" {
appName = appName + ":" + component
}
if len(appName) > postgresIdentifierMaxLen {
appName = appName[:postgresIdentifierMaxLen]
}
return appName
}
// ParseConfigWithApplicationName parses a connection string and applies a default
// application_name when one is not explicitly provided by the caller.
func ParseConfigWithApplicationName(connString, component string) (*pgx.ConnConfig, error) {
cfg, err := pgx.ParseConfig(connString)
if err != nil {
return nil, err
}
if cfg.RuntimeParams == nil {
cfg.RuntimeParams = map[string]string{}
}
if strings.TrimSpace(cfg.RuntimeParams["application_name"]) == "" {
cfg.RuntimeParams["application_name"] = BuildApplicationName(component)
}
return cfg, nil
}
// Connect establishes a PostgreSQL connection with a default relspec
// application_name when the caller does not provide one in the DSN.
func Connect(ctx context.Context, connString, component string) (*pgx.Conn, error) {
cfg, err := ParseConfigWithApplicationName(connString, component)
if err != nil {
return nil, err
}
return pgx.ConnectConfig(ctx, cfg)
}
func relspecVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return "dev"
}
version := strings.TrimSpace(info.Main.Version)
if version != "" && version != "(devel)" {
return version
}
for _, setting := range info.Settings {
if setting.Key == "vcs.revision" {
revision := strings.TrimSpace(setting.Value)
if len(revision) >= 7 {
return revision[:7]
}
if revision != "" {
return revision
}
}
}
return "dev"
}
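A short usage sketch for the helpers above; the component label and DSN are placeholders.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
)

func main() {
	ctx := context.Background()

	// No application_name in the DSN, so Connect falls back to
	// relspecgo/<version>:reader-pgsql (truncated to 63 bytes if needed).
	conn, err := pgsql.Connect(ctx, "postgres://user:pass@localhost:5432/db", "reader-pgsql")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	fmt.Println("application_name =", conn.Config().RuntimeParams["application_name"])
}
```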

@@ -0,0 +1,53 @@
package pgsql
import (
"strings"
"testing"
)
func TestBuildApplicationName_IncludesVersion(t *testing.T) {
got := BuildApplicationName("")
if !strings.HasPrefix(got, "relspecgo/") {
t.Fatalf("BuildApplicationName() = %q, expected prefix relspecgo/", got)
}
}
func TestBuildApplicationName_IncludesComponent(t *testing.T) {
got := BuildApplicationName("reader-pgsql")
if !strings.Contains(got, ":reader-pgsql") {
t.Fatalf("BuildApplicationName(component) = %q, expected component suffix", got)
}
}
func TestBuildApplicationName_RespectsPostgresLengthLimit(t *testing.T) {
got := BuildApplicationName(strings.Repeat("x", 200))
if len(got) > 63 {
t.Fatalf("BuildApplicationName() length = %d, expected <= 63", len(got))
}
}
func TestParseConfigWithApplicationName_AddsWhenMissing(t *testing.T) {
cfg, err := ParseConfigWithApplicationName("postgres://user:pass@localhost:5432/db", "reader-pgsql")
if err != nil {
t.Fatalf("ParseConfigWithApplicationName() error = %v", err)
}
appName := cfg.RuntimeParams["application_name"]
if appName == "" {
t.Fatal("expected application_name to be set")
}
if !strings.HasPrefix(appName, "relspecgo/") {
t.Fatalf("application_name = %q, expected relspecgo/<version> prefix", appName)
}
}
func TestParseConfigWithApplicationName_PreservesExplicitValue(t *testing.T) {
cfg, err := ParseConfigWithApplicationName("postgres://user:pass@localhost:5432/db?application_name=custom-app", "reader-pgsql")
if err != nil {
t.Fatalf("ParseConfigWithApplicationName() error = %v", err)
}
if got := cfg.RuntimeParams["application_name"]; got != "custom-app" {
t.Fatalf("application_name = %q, expected %q", got, "custom-app")
}
}

pkg/pgsql/types_registry.go (new file, 250 lines)

@@ -0,0 +1,250 @@
package pgsql
import (
"sort"
"strings"
)
// TypeSpec describes PostgreSQL type capabilities used by parsers/writers.
type TypeSpec struct {
SupportsLength bool
SupportsPrecision bool
}
var postgresBaseTypes = map[string]TypeSpec{
// Numeric types
"smallint": {},
"integer": {},
"bigint": {},
"decimal": {SupportsPrecision: true},
"numeric": {SupportsPrecision: true},
"real": {},
"double precision": {},
"smallserial": {},
"serial": {},
"bigserial": {},
"money": {},
// Character types
"char": {SupportsLength: true},
"character": {SupportsLength: true},
"varchar": {SupportsLength: true},
"character varying": {SupportsLength: true},
"text": {},
"name": {},
// Binary
"bytea": {},
// Date/time
"timestamp": {SupportsPrecision: true},
"timestamp without time zone": {SupportsPrecision: true},
"timestamp with time zone": {SupportsPrecision: true},
"time": {SupportsPrecision: true},
"time without time zone": {SupportsPrecision: true},
"time with time zone": {SupportsPrecision: true},
"date": {},
"interval": {SupportsPrecision: true},
// Boolean
"boolean": {},
// Geometric
"point": {},
"line": {},
"lseg": {},
"box": {},
"path": {},
"polygon": {},
"circle": {},
// Network
"cidr": {},
"inet": {},
"macaddr": {},
"macaddr8": {},
// Bit string
"bit": {SupportsLength: true},
"bit varying": {SupportsLength: true},
"varbit": {SupportsLength: true},
// Text search
"tsvector": {},
"tsquery": {},
// UUID/XML/JSON
"uuid": {},
"xml": {},
"json": {},
"jsonb": {},
// Range
"int4range": {},
"int8range": {},
"numrange": {},
"tsrange": {},
"tstzrange": {},
"daterange": {},
"int4multirange": {},
"int8multirange": {},
"nummultirange": {},
"tsmultirange": {},
"tstzmultirange": {},
"datemultirange": {},
// Object identifier
"oid": {},
"regclass": {},
"regproc": {},
"regtype": {},
// Pseudo-ish/common built-ins seen in schemas
"record": {},
"void": {},
// Common extensions
"citext": {},
"hstore": {},
"ltree": {},
"lquery": {},
"ltxtquery": {},
"vector": {}, // pgvector: keep explicit modifier form (vector(dim))
"halfvec": {}, // pgvector: keep explicit modifier form (halfvec(dim))
"sparsevec": {}, // pgvector: keep explicit modifier form (sparsevec(dim))
}
var postgresTypeAliases = map[string]string{
// Integer aliases
"int2": "smallint",
"int4": "integer",
"int8": "bigint",
"int": "integer",
// Serial aliases
"serial2": "smallserial",
"serial4": "serial",
"serial8": "bigserial",
// Character aliases
"bpchar": "char",
// Float aliases
"float4": "real",
"float8": "double precision",
"float": "double precision",
// Time aliases
"timestamptz": "timestamp with time zone",
"timetz": "time with time zone",
// Bit alias
"varbit": "bit varying",
// Boolean alias
"bool": "boolean",
}
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
func GetPostgresBaseTypes() []string {
result := make([]string, 0, len(postgresBaseTypes))
for t := range postgresBaseTypes {
result = append(result, t)
}
sort.Strings(result)
return result
}
// GetPostgresTypes returns the registered PostgreSQL types.
// When includeArrays is true, each base type also includes an array variant ("type[]").
func GetPostgresTypes(includeArrays bool) []string {
base := GetPostgresBaseTypes()
if !includeArrays {
return base
}
result := make([]string, 0, len(base)*2)
result = append(result, base...)
for _, t := range base {
result = append(result, t+"[]")
}
return result
}
// ExtractBaseType returns the type without outer array suffixes and modifiers.
// Examples:
// - varchar(255) -> varchar
// - text[] -> text
// - numeric(10,2)[] -> numeric
func ExtractBaseType(sqlType string) string {
t := normalizeTypeToken(sqlType)
t = strings.TrimSpace(stripArraySuffixes(t))
if idx := strings.Index(t, "("); idx > 0 {
t = strings.TrimSpace(t[:idx])
}
return t
}
// ExtractBaseTypeLower is ExtractBaseType with lowercase normalization.
func ExtractBaseTypeLower(sqlType string) string {
return strings.ToLower(ExtractBaseType(sqlType))
}
// IsArrayType reports whether the SQL type has one or more [] suffixes.
func IsArrayType(sqlType string) bool {
t := normalizeTypeToken(sqlType)
return strings.HasSuffix(t, "[]")
}
// ElementType returns the underlying element type for array types.
// For non-array types, it returns the input unchanged.
func ElementType(sqlType string) string {
t := normalizeTypeToken(sqlType)
return stripArraySuffixes(t)
}
// CanonicalizeBaseType resolves aliases to canonical PostgreSQL type names.
func CanonicalizeBaseType(baseType string) string {
base := strings.ToLower(normalizeTypeToken(baseType))
if canonical, ok := postgresTypeAliases[base]; ok {
return canonical
}
return base
}
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
func IsKnownPostgresType(sqlType string) bool {
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
_, ok := postgresBaseTypes[base]
return ok
}
// SupportsLength reports if this SQL type accepts a single length/dimension modifier.
func SupportsLength(sqlType string) bool {
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
spec, ok := postgresBaseTypes[base]
return ok && spec.SupportsLength
}
// SupportsPrecision reports if this SQL type accepts precision (and possibly scale).
func SupportsPrecision(sqlType string) bool {
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
spec, ok := postgresBaseTypes[base]
return ok && spec.SupportsPrecision
}
// HasExplicitTypeModifier reports if the type already includes "(...)".
func HasExplicitTypeModifier(sqlType string) bool {
return strings.Contains(sqlType, "(")
}
func stripArraySuffixes(t string) string {
for strings.HasSuffix(t, "[]") {
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
}
return t
}
func normalizeTypeToken(t string) string {
return strings.Join(strings.Fields(strings.TrimSpace(t)), " ")
}

@@ -0,0 +1,99 @@
package pgsql
import "testing"
func TestPostgresTypeRegistry_MasterListIncludesRequestedTypes(t *testing.T) {
required := []string{
"vector",
"integer",
"citext",
}
types := make(map[string]bool)
for _, typ := range GetPostgresTypes(true) {
types[typ] = true
}
for _, typ := range required {
if !types[typ] {
t.Fatalf("master type list missing %q", typ)
}
if !types[typ+"[]"] {
t.Fatalf("master type list missing array variant %q", typ+"[]")
}
}
}
func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) {
tests := []struct {
input string
wantBase string
wantCanonicalBase string
wantArray bool
wantKnown bool
wantLength bool
wantPrecision bool
}{
{
input: "integer[]",
wantBase: "integer",
wantCanonicalBase: "integer",
wantArray: true,
wantKnown: true,
},
{
input: "citext[]",
wantBase: "citext",
wantCanonicalBase: "citext",
wantArray: true,
wantKnown: true,
},
{
input: "vector(1536)",
wantBase: "vector",
wantCanonicalBase: "vector",
wantKnown: true,
wantLength: false,
},
{
input: "numeric(10,2)",
wantBase: "numeric",
wantCanonicalBase: "numeric",
wantKnown: true,
wantPrecision: true,
},
{
input: "int4",
wantBase: "int4",
wantCanonicalBase: "integer",
wantKnown: true,
},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
base := ExtractBaseTypeLower(tt.input)
if base != tt.wantBase {
t.Fatalf("ExtractBaseTypeLower(%q) = %q, want %q", tt.input, base, tt.wantBase)
}
canonical := CanonicalizeBaseType(base)
if canonical != tt.wantCanonicalBase {
t.Fatalf("CanonicalizeBaseType(%q) = %q, want %q", base, canonical, tt.wantCanonicalBase)
}
if IsArrayType(tt.input) != tt.wantArray {
t.Fatalf("IsArrayType(%q) = %v, want %v", tt.input, IsArrayType(tt.input), tt.wantArray)
}
if IsKnownPostgresType(tt.input) != tt.wantKnown {
t.Fatalf("IsKnownPostgresType(%q) = %v, want %v", tt.input, IsKnownPostgresType(tt.input), tt.wantKnown)
}
if SupportsLength(tt.input) != tt.wantLength {
t.Fatalf("SupportsLength(%q) = %v, want %v", tt.input, SupportsLength(tt.input), tt.wantLength)
}
if SupportsPrecision(tt.input) != tt.wantPrecision {
t.Fatalf("SupportsPrecision(%q) = %v, want %v", tt.input, SupportsPrecision(tt.input), tt.wantPrecision)
}
})
}
}

@@ -12,6 +12,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
) )
@@ -700,16 +701,22 @@ func (r *Reader) extractBunTag(tag string) string {
// parseTypeWithLength parses a type string and extracts length if present // parseTypeWithLength parses a type string and extracts length if present
// e.g., "varchar(255)" returns ("varchar", 255) // e.g., "varchar(255)" returns ("varchar", 255)
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) { func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
typeStr = strings.TrimSpace(typeStr)
baseType = typeStr
// Check for type with length: varchar(255), char(10), etc. // Check for type with length: varchar(255), char(10), etc.
re := regexp.MustCompile(`^([a-zA-Z\s]+)\((\d+)\)$`) re := regexp.MustCompile(`^([a-zA-Z\s]+)\((\d+)\)$`)
matches := re.FindStringSubmatch(typeStr) matches := re.FindStringSubmatch(typeStr)
if len(matches) == 3 { if len(matches) == 3 {
rawBaseType := strings.TrimSpace(matches[1])
if pgsql.SupportsLength(rawBaseType) {
if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil { if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil {
baseType = strings.TrimSpace(matches[1]) baseType = pgsql.CanonicalizeBaseType(rawBaseType)
return return
} }
} }
baseType = typeStr }
return return
} }

@@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) {
if !emailCol.NotNull { if !emailCol.NotNull {
t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)") t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)")
} }
if emailCol.Type != "varchar" || emailCol.Length != 255 { if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" {
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
}
if emailCol.Length != 255 {
t.Errorf("Expected email length 255, got %d", emailCol.Length)
} }
// Verify name column - primitive string type should be NOT NULL by default in Bun // Verify name column - primitive string type should be NOT NULL by default in Bun
@@ -356,6 +359,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
} }
} }
func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) {
reader := &Reader{}
tests := []struct {
input string
wantType string
wantLength int
}{
{"varchar(255)", "varchar", 255},
{"character varying(120)", "character varying", 120},
{"vector(1536)", "vector(1536)", 0},
{"numeric(10,2)", "numeric(10,2)", 0},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
gotType, gotLength := reader.parseTypeWithLength(tt.input)
if gotType != tt.wantType {
t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
}
if gotLength != tt.wantLength {
t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength)
}
})
}
}
func TestReader_ReadSchema(t *testing.T) { func TestReader_ReadSchema(t *testing.T) {
opts := &readers.ReaderOptions{ opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"), FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"),

View File

@@ -567,25 +567,20 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
// parseColumn parses a DBML column definition // parseColumn parses a DBML column definition
func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) { func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) {
// Format: column_name type [attributes] // comment // Format: column_name type [attributes] // comment
parts := strings.Fields(line) lineNoComment, inlineComment := splitInlineComment(line)
if len(parts) < 2 { signature, attrs := splitColumnSignatureAndAttrs(lineNoComment)
columnName, columnType, ok := parseColumnSignature(signature)
if !ok {
return nil, nil return nil, nil
} }
columnName := stripQuotes(parts[0])
columnType := stripQuotes(parts[1])
column := models.InitColumn(columnName, tableName, schemaName) column := models.InitColumn(columnName, tableName, schemaName)
column.Type = columnType column.Type = columnType
var constraint *models.Constraint var constraint *models.Constraint
// Parse attributes in brackets // Parse attributes in brackets
if strings.Contains(line, "[") && strings.Contains(line, "]") { if attrs != "" {
attrStart := strings.Index(line, "[")
attrEnd := strings.Index(line, "]")
if attrStart < attrEnd {
attrs := line[attrStart+1 : attrEnd]
attrList := strings.Split(attrs, ",") attrList := strings.Split(attrs, ",")
for _, attr := range attrList { for _, attr := range attrList {
@@ -660,17 +655,94 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
} }
} }
} }
}
// Parse inline comment // Parse inline comment
if strings.Contains(line, "//") { if inlineComment != "" {
commentStart := strings.Index(line, "//") column.Comment = inlineComment
column.Comment = strings.TrimSpace(line[commentStart+2:])
} }
return column, constraint return column, constraint
} }
func splitInlineComment(line string) (content string, inlineComment string) {
commentStart := strings.Index(line, "//")
if commentStart == -1 {
return line, ""
}
return strings.TrimSpace(line[:commentStart]), strings.TrimSpace(line[commentStart+2:])
}
func splitColumnSignatureAndAttrs(line string) (signature string, attrs string) {
trimmed := strings.TrimSpace(line)
if trimmed == "" || !strings.HasSuffix(trimmed, "]") {
return trimmed, ""
}
bracketDepth := 0
for i := len(trimmed) - 1; i >= 0; i-- {
switch trimmed[i] {
case ']':
bracketDepth++
case '[':
bracketDepth--
if bracketDepth == 0 {
// DBML attributes are a trailing [ ... ] block preceded by whitespace.
// This avoids confusing array types like text[] with attribute blocks.
if i > 0 && (trimmed[i-1] == ' ' || trimmed[i-1] == '\t') {
return strings.TrimSpace(trimmed[:i]), strings.TrimSpace(trimmed[i+1 : len(trimmed)-1])
}
}
}
}
return trimmed, ""
}
func parseColumnSignature(signature string) (columnName string, columnType string, ok bool) {
signature = strings.TrimSpace(signature)
if signature == "" {
return "", "", false
}
var splitAt int
if signature[0] == '"' || signature[0] == '\'' {
quote := signature[0]
splitAt = 1
for splitAt < len(signature) {
if signature[splitAt] == quote {
splitAt++
break
}
splitAt++
}
} else {
for splitAt < len(signature) && signature[splitAt] != ' ' && signature[splitAt] != '\t' {
splitAt++
}
}
if splitAt <= 0 || splitAt >= len(signature) {
return "", "", false
}
columnName = stripQuotes(strings.TrimSpace(signature[:splitAt]))
columnType = stripWrappingQuotes(strings.TrimSpace(signature[splitAt:]))
if columnName == "" || columnType == "" {
return "", "", false
}
return columnName, columnType, true
}
func stripWrappingQuotes(s string) string {
s = strings.TrimSpace(s)
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
return s[1 : len(s)-1]
}
return s
}
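
A minimal test-style sketch of the helpers above (hypothetical name and cases; assumes it sits in the same package as the DBML reader), showing how a column line is split before any type parsing happens:

```go
// Hypothetical in-package sketch of the DBML column-line helpers.
func TestDBMLColumnHelpers_Sketch(t *testing.T) {
	// splitInlineComment separates the trailing // comment from the column body.
	body, comment := splitInlineComment("labels varchar(20)[] // column labels")
	if body != "labels varchar(20)[]" || comment != "column labels" {
		t.Fatalf("unexpected comment split: %q / %q", body, comment)
	}

	// Only a trailing [...] block preceded by whitespace counts as attributes,
	// so array types like text[] stay part of the signature.
	sig, attrs := splitColumnSignatureAndAttrs("tags text[] [not null]")
	if sig != "tags text[]" || attrs != "not null" {
		t.Fatalf("unexpected attr split: %q / %q", sig, attrs)
	}

	// Quoted names and multi-word types are handled in one pass.
	name, typ, ok := parseColumnSignature(`"published_at" timestamp with time zone`)
	if !ok || name != "published_at" || typ != "timestamp with time zone" {
		t.Fatalf("unexpected signature: %q %q %v", name, typ, ok)
	}
}
```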
// parseIndex parses a DBML index definition // parseIndex parses a DBML index definition
func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index { func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Format: (columns) [attributes] OR columnname [attributes] // Format: (columns) [attributes] OR columnname [attributes]

View File

@@ -839,6 +839,67 @@ func TestConstraintNaming(t *testing.T) {
} }
} }
func TestParseColumn_PostgresTypes(t *testing.T) {
reader := &Reader{}
tests := []struct {
name string
line string
wantName string
wantType string
wantNotNull bool
wantComment string
}{
{
name: "array type with attrs",
line: "tags text[] [not null]",
wantName: "tags",
wantType: "text[]",
wantNotNull: true,
},
{
name: "vector with dimension",
line: "embedding vector(1536)",
wantName: "embedding",
wantType: "vector(1536)",
},
{
name: "multi word timestamp type",
line: "published_at timestamp with time zone",
wantName: "published_at",
wantType: "timestamp with time zone",
},
{
name: "array type with inline comment",
line: "labels varchar(20)[] // column labels",
wantName: "labels",
wantType: "varchar(20)[]",
wantComment: "column labels",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
col, _ := reader.parseColumn(tt.line, "events", "public")
if col == nil {
t.Fatalf("parseColumn() returned nil column")
}
if col.Name != tt.wantName {
t.Errorf("column name = %q, want %q", col.Name, tt.wantName)
}
if col.Type != tt.wantType {
t.Errorf("column type = %q, want %q", col.Type, tt.wantType)
}
if col.NotNull != tt.wantNotNull {
t.Errorf("column not null = %v, want %v", col.NotNull, tt.wantNotNull)
}
if col.Comment != tt.wantComment {
t.Errorf("column comment = %q, want %q", col.Comment, tt.wantComment)
}
})
}
}
func getKeys[V any](m map[string]V) []string { func getKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m)) keys := make([]string, 0, len(m))
for k := range m { for k := range m {

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
) )
@@ -232,7 +233,19 @@ func (r *Reader) convertField(dctxField *models.DCTXField, tableName string) ([]
// mapDataType maps Clarion data types to SQL types // mapDataType maps Clarion data types to SQL types
func (r *Reader) mapDataType(clarionType string, size int) (sqlType string, precision int) { func (r *Reader) mapDataType(clarionType string, size int) (sqlType string, precision int) {
switch strings.ToUpper(clarionType) { trimmedType := strings.TrimSpace(clarionType)
// Preserve known PostgreSQL types (including arrays and extension types)
// from DCTX input instead of coercing them to generic text.
if pgsql.IsKnownPostgresType(trimmedType) {
pgType := canonicalizePostgresType(trimmedType)
if !pgsql.HasExplicitTypeModifier(pgType) && size > 0 && pgsql.SupportsLength(pgType) {
return pgType, size
}
return pgType, 0
}
switch strings.ToUpper(trimmedType) {
case "LONG": case "LONG":
if size == 8 { if size == 8 {
return "bigint", 0 return "bigint", 0
@@ -306,6 +319,32 @@ func (r *Reader) mapDataType(clarionType string, size int) (sqlType string, prec
} }
} }
func canonicalizePostgresType(typeStr string) string {
t := strings.ToLower(strings.Join(strings.Fields(strings.TrimSpace(typeStr)), " "))
if t == "" {
return ""
}
// Handle array suffixes
arrayCount := 0
for strings.HasSuffix(t, "[]") {
arrayCount++
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
}
// Handle optional type modifier
modifier := ""
if idx := strings.Index(t, "("); idx > 0 {
if end := strings.LastIndex(t, ")"); end > idx {
modifier = t[idx : end+1]
t = strings.TrimSpace(t[:idx])
}
}
base := pgsql.CanonicalizeBaseType(t)
return base + modifier + strings.Repeat("[]", arrayCount)
}
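
As a rough illustration of the normalization above, a hypothetical in-package sketch (test-style, reusing the alias/array/modifier cases that the DCTX tests later in this change rely on): the function lower-cases and collapses whitespace, keeps any `[]` suffixes and `(...)` modifier, and canonicalizes only the base name.

```go
// Hypothetical sketch of canonicalizePostgresType behaviour.
func TestCanonicalizePostgresType_Sketch(t *testing.T) {
	cases := map[string]string{
		"INT4[]":       "integer[]",    // alias canonicalized, array suffix kept
		"  integer[] ": "integer[]",    // case and whitespace normalized
		"vector(1536)": "vector(1536)", // type modifier preserved verbatim
		"citext[]":     "citext[]",     // extension types pass through unchanged
	}
	for in, want := range cases {
		if got := canonicalizePostgresType(in); got != want {
			t.Fatalf("canonicalizePostgresType(%q) = %q, want %q", in, got, want)
		}
	}
}
```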
// processKeys processes DCTX keys and converts them to indexes and primary keys // processKeys processes DCTX keys and converts them to indexes and primary keys
func (r *Reader) processKeys(dctxTable *models.DCTXTable, table *models.Table, fieldGuidMap map[string]string) error { func (r *Reader) processKeys(dctxTable *models.DCTXTable, table *models.Table, fieldGuidMap map[string]string) error {
for _, dctxKey := range dctxTable.Keys { for _, dctxKey := range dctxTable.Keys {

View File

@@ -493,3 +493,55 @@ func TestRelationships(t *testing.T) {
} }
} }
} }
func TestMapDataType_PostgresTypes(t *testing.T) {
reader := &Reader{}
tests := []struct {
name string
inputType string
size int
wantType string
wantLength int
}{
{
name: "integer array preserved",
inputType: "integer[]",
wantType: "integer[]",
},
{
name: "citext array preserved",
inputType: "citext[]",
wantType: "citext[]",
},
{
name: "vector modifier preserved",
inputType: "vector(1536)",
wantType: "vector(1536)",
},
{
name: "alias canonicalized in array",
inputType: "int4[]",
wantType: "integer[]",
},
{
name: "varchar length from size",
inputType: "varchar",
size: 120,
wantType: "varchar",
wantLength: 120,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotType, gotLength := reader.mapDataType(tt.inputType, tt.size)
if gotType != tt.wantType {
t.Fatalf("mapDataType(%q, %d) type = %q, want %q", tt.inputType, tt.size, gotType, tt.wantType)
}
if gotLength != tt.wantLength {
t.Fatalf("mapDataType(%q, %d) length = %d, want %d", tt.inputType, tt.size, gotLength, tt.wantLength)
}
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb" "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
) )
@@ -231,17 +232,19 @@ func (r *Reader) convertToColumn(field *drawdb.DrawDBField, tableName, schemaNam
// Parse type and dimensions // Parse type and dimensions
typeStr := field.Type typeStr := field.Type
typeStr = strings.TrimSpace(typeStr)
column.Type = typeStr column.Type = typeStr
// Try to extract length/precision from type string like "varchar(255)" or "decimal(10,2)" // Try to extract length/precision from type string like "varchar(255)" or "decimal(10,2)"
if strings.Contains(typeStr, "(") { if strings.Contains(typeStr, "(") {
parts := strings.Split(typeStr, "(") parts := strings.Split(typeStr, "(")
column.Type = parts[0] baseType := strings.TrimSpace(parts[0])
if len(parts) > 1 { if len(parts) > 1 {
dimensions := strings.TrimSuffix(parts[1], ")") dimensions := strings.TrimSuffix(parts[1], ")")
if strings.Contains(dimensions, ",") { if strings.Contains(dimensions, ",") {
// Precision and scale (e.g., decimal(10,2)) // Precision and scale (e.g., decimal(10,2), numeric(10,2))
if pgsql.SupportsPrecision(baseType) {
dims := strings.Split(dimensions, ",") dims := strings.Split(dimensions, ",")
if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil { if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil {
column.Precision = precision column.Precision = precision
@@ -251,14 +254,17 @@ func (r *Reader) convertToColumn(field *drawdb.DrawDBField, tableName, schemaNam
column.Scale = scale column.Scale = scale
} }
} }
}
} else { } else {
// Just length (e.g., varchar(255)) // Just length (e.g., varchar(255))
if pgsql.SupportsLength(baseType) {
if length, err := strconv.Atoi(dimensions); err == nil { if length, err := strconv.Atoi(dimensions); err == nil {
column.Length = length column.Length = length
} }
} }
} }
} }
}
column.IsPrimaryKey = field.Primary column.IsPrimaryKey = field.Primary
column.NotNull = field.NotNull || field.Primary column.NotNull = field.NotNull || field.Primary

View File

@@ -6,6 +6,7 @@ import (
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
) )
func TestReader_ReadDatabase_Simple(t *testing.T) { func TestReader_ReadDatabase_Simple(t *testing.T) {
@@ -288,6 +289,61 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
} }
} }
func TestConvertToColumn_PreservesExplicitTypeModifiers(t *testing.T) {
reader := &Reader{}
tests := []struct {
name string
fieldType string
wantType string
wantLength int
wantPrecision int
wantScale int
}{
{
name: "varchar with length",
fieldType: "varchar(255)",
wantType: "varchar(255)",
wantLength: 255,
},
{
name: "numeric precision/scale",
fieldType: "numeric(10,2)",
wantType: "numeric(10,2)",
wantPrecision: 10,
wantScale: 2,
},
{
name: "custom vector modifier",
fieldType: "vector(1536)",
wantType: "vector(1536)",
wantLength: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
field := &drawdb.DrawDBField{
Name: tt.name,
Type: tt.fieldType,
}
col := reader.convertToColumn(field, "events", "public")
if col.Type != tt.wantType {
t.Fatalf("column type = %q, want %q", col.Type, tt.wantType)
}
if col.Length != tt.wantLength {
t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength)
}
if col.Precision != tt.wantPrecision {
t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision)
}
if col.Scale != tt.wantScale {
t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale)
}
})
}
}
func TestReader_ReadSchema(t *testing.T) { func TestReader_ReadSchema(t *testing.T) {
opts := &readers.ReaderOptions{ opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "drawdb", "simple.json"), FilePath: filepath.Join("..", "..", "..", "tests", "assets", "drawdb", "simple.json"),

View File

@@ -12,6 +12,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
) )
@@ -676,19 +677,8 @@ func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaNa
// deriveTableName derives a table name from struct name // deriveTableName derives a table name from struct name
func (r *Reader) deriveTableName(structName string) string { func (r *Reader) deriveTableName(structName string) string {
// Remove "Model" prefix if present // Remove "Model" prefix if present, use the name as-is without transformation
name := strings.TrimPrefix(structName, "Model") return strings.TrimPrefix(structName, "Model")
// Convert PascalCase to snake_case
var result strings.Builder
for i, r := range name {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
} }
// parseColumn parses a struct field into a Column model // parseColumn parses a struct field into a Column model
@@ -784,11 +774,14 @@ func (r *Reader) extractGormTag(tag string) string {
// parseTypeWithLength parses a type string and extracts length if present // parseTypeWithLength parses a type string and extracts length if present
// e.g., "varchar(255)" returns ("varchar", 255) // e.g., "varchar(255)" returns ("varchar", 255)
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) { func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
typeStr = strings.TrimSpace(typeStr)
baseType = typeStr
// Check for type with length: varchar(255), char(10), etc. // Check for type with length: varchar(255), char(10), etc.
// Also handle precision/scale: numeric(10,2) // Also handle precision/scale: numeric(10,2)
if strings.Contains(typeStr, "(") { if strings.Contains(typeStr, "(") {
idx := strings.Index(typeStr, "(") idx := strings.Index(typeStr, "(")
baseType = strings.TrimSpace(typeStr[:idx]) rawBaseType := strings.TrimSpace(typeStr[:idx])
// Extract numbers from parentheses // Extract numbers from parentheses
parens := typeStr[idx+1:] parens := typeStr[idx+1:]
@@ -796,14 +789,16 @@ func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length in
parens = parens[:endIdx] parens = parens[:endIdx]
} }
// For now, just handle single number (length) // Only treat as "length" for text-ish SQL types.
if !strings.Contains(parens, ",") { // This avoids converting custom modifiers like vector(1536) into Length.
if pgsql.SupportsLength(rawBaseType) && !strings.Contains(parens, ",") {
if _, err := fmt.Sscanf(parens, "%d", &length); err == nil { if _, err := fmt.Sscanf(parens, "%d", &length); err == nil {
baseType = pgsql.CanonicalizeBaseType(rawBaseType)
return return
} }
} }
} }
baseType = typeStr
return return
} }

View File

@@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) {
if !emailCol.NotNull { if !emailCol.NotNull {
t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)") t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)")
} }
if emailCol.Type != "varchar" || emailCol.Length != 255 { if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" {
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
}
if emailCol.Length != 255 {
t.Errorf("Expected email length 255, got %d", emailCol.Length)
} }
// Verify name column - primitive string type should be NOT NULL by default // Verify name column - primitive string type should be NOT NULL by default
@@ -363,6 +366,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) {
} }
} }
func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) {
reader := &Reader{}
tests := []struct {
input string
wantType string
wantLength int
}{
{"varchar(255)", "varchar", 255},
{"character varying(120)", "character varying", 120},
{"vector(1536)", "vector(1536)", 0},
{"numeric(10,2)", "numeric(10,2)", 0},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
gotType, gotLength := reader.parseTypeWithLength(tt.input)
if gotType != tt.wantType {
t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
}
if gotLength != tt.wantLength {
t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength)
}
})
}
}
func TestReader_ReadSchema(t *testing.T) { func TestReader_ReadSchema(t *testing.T) {
opts := &readers.ReaderOptions{ opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"), FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"),

View File

@@ -89,6 +89,10 @@ postgres://user@localhost/mydb?sslmode=disable
postgres://user:pass@db.example.com:5432/production?sslmode=require postgres://user:pass@db.example.com:5432/production?sslmode=require
``` ```
By default, relspec sets `application_name` to `relspecgo/<version>` for PostgreSQL
sessions so they are identifiable in `pg_stat_activity`. If you provide
`application_name` in the connection string, your explicit value is preserved.
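For example, an explicit label in the connection string is kept rather than overwritten (illustrative values):
```
postgres://user:pass@db.example.com:5432/production?sslmode=require&application_name=my-service
```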
## Extracted Information ## Extracted Information
### Tables ### Tables

View File

@@ -206,8 +206,19 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
c.numeric_precision, c.numeric_precision,
c.numeric_scale, c.numeric_scale,
c.udt_name, c.udt_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_data_type,
col_description((c.table_schema||'.'||c.table_name)::regclass, c.ordinal_position) as description col_description((c.table_schema||'.'||c.table_name)::regclass, c.ordinal_position) as description
FROM information_schema.columns c FROM information_schema.columns c
JOIN pg_catalog.pg_namespace n
ON n.nspname = c.table_schema
JOIN pg_catalog.pg_class cls
ON cls.relname = c.table_name
AND cls.relnamespace = n.oid
JOIN pg_catalog.pg_attribute a
ON a.attrelid = cls.oid
AND a.attname = c.column_name
AND a.attnum > 0
AND NOT a.attisdropped
WHERE c.table_schema = $1 WHERE c.table_schema = $1
ORDER BY c.table_schema, c.table_name, c.ordinal_position ORDER BY c.table_schema, c.table_name, c.ordinal_position
` `
@@ -221,12 +232,12 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
columnsMap := make(map[string]map[string]*models.Column) columnsMap := make(map[string]map[string]*models.Column)
for rows.Next() { for rows.Next() {
var schema, tableName, columnName, isNullable, dataType, udtName string var schema, tableName, columnName, isNullable, dataType, udtName, formattedDataType string
var ordinalPosition int var ordinalPosition int
var columnDefault, description *string var columnDefault, description *string
var charMaxLength, numPrecision, numScale *int var charMaxLength, numPrecision, numScale *int
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &description); err != nil { if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &formattedDataType, &description); err != nil {
return nil, err return nil, err
} }
@@ -241,12 +252,12 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
column.AutoIncrement = true column.AutoIncrement = true
column.Default = defaultVal column.Default = defaultVal
} else { } else {
column.Default = defaultVal column.Default = normalizePostgresDefault(defaultVal)
} }
} }
// Map data type, preserving serial types when detected // Map data type, preserving serial types when detected
column.Type = r.mapDataType(dataType, udtName, hasNextval) column.Type = r.mapDataType(dataType, udtName, formattedDataType, hasNextval)
column.NotNull = (isNullable == "NO") column.NotNull = (isNullable == "NO")
column.Sequence = uint(ordinalPosition) column.Sequence = uint(ordinalPosition)
@@ -602,3 +613,30 @@ func (r *Reader) parseIndexDefinition(indexName, tableName, schema, indexDef str
return index, nil return index, nil
} }
// normalizePostgresDefault converts a raw PostgreSQL column_default expression into the
// unquoted string value that the model convention expects. PostgreSQL stores string
// literal defaults as 'value' or 'value'::type (e.g. '{}'::text[]), while every other
// reader stores the bare value so the writer can re-quote it correctly.
func normalizePostgresDefault(defaultVal string) string {
if !strings.HasPrefix(defaultVal, "'") {
return defaultVal
}
// Decode the SQL string literal: skip the leading quote, unescape '' → ', stop at
// the first unescaped closing quote (any trailing ::cast is ignored).
rest := defaultVal[1:]
var buf strings.Builder
for i := 0; i < len(rest); i++ {
if rest[i] == '\'' {
if i+1 < len(rest) && rest[i+1] == '\'' {
buf.WriteByte('\'')
i++
} else {
break
}
} else {
buf.WriteByte(rest[i])
}
}
return buf.String()
}
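
A hypothetical in-package sketch of the helper above (test-style; cases derived from the decoding rules described in the comment, not from additional fixtures):

```go
// Illustrative behaviour of normalizePostgresDefault.
func TestNormalizePostgresDefault_Sketch(t *testing.T) {
	cases := map[string]string{
		"'{}'::text[]":   "{}",             // surrounding quotes and ::cast stripped
		"'it''s fine'":   "it's fine",      // doubled single quotes unescaped
		"now()":          "now()",          // function defaults pass through untouched
		"nextval('seq')": "nextval('seq')", // no leading quote, returned as-is
	}
	for in, want := range cases {
		if got := normalizePostgresDefault(in); got != want {
			t.Fatalf("normalizePostgresDefault(%q) = %q, want %q", in, got, want)
		}
	}
}
```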

View File

@@ -244,7 +244,7 @@ func (r *Reader) ReadTable() (*models.Table, error) {
// connect establishes a connection to the PostgreSQL database // connect establishes a connection to the PostgreSQL database
func (r *Reader) connect() error { func (r *Reader) connect() error {
conn, err := pgx.Connect(r.ctx, r.options.ConnectionString) conn, err := pgsql.Connect(r.ctx, r.options.ConnectionString, "reader-pgsql")
if err != nil { if err != nil {
return err return err
} }
@@ -259,12 +259,14 @@ func (r *Reader) close() {
} }
} }
// mapDataType maps PostgreSQL data types to canonical types // mapDataType maps PostgreSQL data types while preserving exact type text when available.
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string { func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
// If the column has a nextval default, it's likely a serial type // If the column has a nextval default, it's likely a serial type
// Map to the appropriate serial type instead of the base integer type // Map to the appropriate serial type instead of the base integer type
if hasNextval { if hasNextval {
switch strings.ToLower(pgType) { switch normalizedPGType {
case "integer", "int", "int4": case "integer", "int", "int4":
return "serial" return "serial"
case "bigint", "int8": case "bigint", "int8":
@@ -274,6 +276,17 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
} }
} }
// Prefer the database-provided formatted type; this preserves arrays/custom
// types/modifiers like text[], vector(1536), numeric(10,2), etc.
if strings.TrimSpace(formattedType) != "" {
return formattedType
}
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
return udtName[1:] + "[]"
}
// Map common PostgreSQL types // Map common PostgreSQL types
typeMap := map[string]string{ typeMap := map[string]string{
"integer": "integer", "integer": "integer",
@@ -320,7 +333,7 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
} }
// Try mapped type first // Try mapped type first
if mapped, exists := typeMap[pgType]; exists { if mapped, exists := typeMap[normalizedPGType]; exists {
return mapped return mapped
} }
@@ -329,8 +342,11 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
return pgsql.GetSQLType(pgType) return pgsql.GetSQLType(pgType)
} }
// Return UDT name for custom types // Return UDT name for custom types (including array fallback when needed)
if udtName != "" { if udtName != "" {
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
return udtName[1:] + "[]"
}
return udtName return udtName
} }

View File

@@ -175,33 +175,37 @@ func TestMapDataType(t *testing.T) {
tests := []struct { tests := []struct {
pgType string pgType string
udtName string udtName string
formattedType string
expected string expected string
}{ }{
{"integer", "int4", "integer"}, {"integer", "int4", "", "integer"},
{"bigint", "int8", "bigint"}, {"bigint", "int8", "", "bigint"},
{"smallint", "int2", "smallint"}, {"smallint", "int2", "", "smallint"},
{"character varying", "varchar", "varchar"}, {"character varying", "varchar", "", "varchar"},
{"text", "text", "text"}, {"text", "text", "", "text"},
{"boolean", "bool", "boolean"}, {"boolean", "bool", "", "boolean"},
{"timestamp without time zone", "timestamp", "timestamp"}, {"timestamp without time zone", "timestamp", "", "timestamp"},
{"timestamp with time zone", "timestamptz", "timestamptz"}, {"timestamp with time zone", "timestamptz", "", "timestamptz"},
{"json", "json", "json"}, {"json", "json", "", "json"},
{"jsonb", "jsonb", "jsonb"}, {"jsonb", "jsonb", "", "jsonb"},
{"uuid", "uuid", "uuid"}, {"uuid", "uuid", "", "uuid"},
{"numeric", "numeric", "numeric"}, {"numeric", "numeric", "", "numeric"},
{"real", "float4", "real"}, {"real", "float4", "", "real"},
{"double precision", "float8", "double precision"}, {"double precision", "float8", "", "double precision"},
{"date", "date", "date"}, {"date", "date", "", "date"},
{"time without time zone", "time", "time"}, {"time without time zone", "time", "", "time"},
{"bytea", "bytea", "bytea"}, {"bytea", "bytea", "", "bytea"},
{"unknown_type", "custom", "custom"}, // Should return UDT name {"unknown_type", "custom", "", "custom"}, // Should return UDT name
{"ARRAY", "_text", "", "text[]"},
{"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"},
{"character varying", "varchar", "character varying(255)", "character varying(255)"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.pgType, func(t *testing.T) { t.Run(tt.pgType, func(t *testing.T) {
result := reader.mapDataType(tt.pgType, tt.udtName, false) result := reader.mapDataType(tt.pgType, tt.udtName, tt.formattedType, false)
if result != tt.expected { if result != tt.expected {
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected) t.Errorf("mapDataType(%s, %s, %s) = %s, expected %s", tt.pgType, tt.udtName, tt.formattedType, result, tt.expected)
} }
}) })
} }
@@ -218,9 +222,9 @@ func TestMapDataType(t *testing.T) {
for _, tt := range serialTests { for _, tt := range serialTests {
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) { t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
result := reader.mapDataType(tt.pgType, "", true) result := reader.mapDataType(tt.pgType, "", "", true)
if result != tt.expected { if result != tt.expected {
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected) t.Errorf("mapDataType(%s, '', '', true) = %s, expected %s", tt.pgType, result, tt.expected)
} }
}) })
} }

View File

@@ -70,6 +70,7 @@ func (r *Reader) ReadTable() (*models.Table, error) {
// parsePrisma parses Prisma schema content and returns a Database model // parsePrisma parses Prisma schema content and returns a Database model
func (r *Reader) parsePrisma(content string) (*models.Database, error) { func (r *Reader) parsePrisma(content string) (*models.Database, error) {
db := models.InitDatabase("database") db := models.InitDatabase("database")
db.SourceFormat = "prisma"
if r.options.Metadata != nil { if r.options.Metadata != nil {
if name, ok := r.options.Metadata["name"].(string); ok { if name, ok := r.options.Metadata["name"].(string); ok {
@@ -139,7 +140,7 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
case "datasource": case "datasource":
r.parseDatasource(blockContent, db) r.parseDatasource(blockContent, db)
case "generator": case "generator":
// We don't need to do anything with generator blocks r.parseGenerator(blockContent, db)
case "model": case "model":
if currentTable != nil { if currentTable != nil {
r.parseModelFields(blockContent, currentTable) r.parseModelFields(blockContent, currentTable)
@@ -173,10 +174,34 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
// Second pass: resolve relationships // Second pass: resolve relationships
r.resolveRelationships(schema) r.resolveRelationships(schema)
if db.SourceFormat == "prisma" && r.options != nil && r.options.Prisma7 {
db.SourceFormat = "prisma7"
}
db.Schemas = append(db.Schemas, schema) db.Schemas = append(db.Schemas, schema)
return db, nil return db, nil
} }
func (r *Reader) parseGenerator(lines []string, db *models.Database) {
providerRegex := regexp.MustCompile(`provider\s*=\s*"([^"]+)"`)
for _, line := range lines {
if matches := providerRegex.FindStringSubmatch(line); matches != nil {
switch matches[1] {
case "prisma-client":
db.SourceFormat = "prisma7"
default:
db.SourceFormat = "prisma"
}
return
}
}
if r.options != nil && r.options.Prisma7 {
db.SourceFormat = "prisma7"
}
}
// parseDatasource extracts database type from datasource block // parseDatasource extracts database type from datasource block
func (r *Reader) parseDatasource(lines []string, db *models.Database) { func (r *Reader) parseDatasource(lines []string, db *models.Database) {
providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`) providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)

View File

@@ -0,0 +1,77 @@
package prisma
import (
"os"
"path/filepath"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
func TestReadDatabase_Prisma7GeneratorSetsSourceFormat(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
schemaPath := filepath.Join(tmpDir, "schema.prisma")
content := `datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client"
output = "./generated"
}
model User {
id Int @id @default(autoincrement())
}`
if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil {
t.Fatalf("failed to write schema: %v", err)
}
reader := NewReader(&readers.ReaderOptions{FilePath: schemaPath})
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() failed: %v", err)
}
if db.SourceFormat != "prisma7" {
t.Fatalf("expected SourceFormat prisma7, got %q", db.SourceFormat)
}
}
func TestReadDatabase_Prisma7FlagSetsSourceFormatWithoutGenerator(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
schemaPath := filepath.Join(tmpDir, "schema.prisma")
content := `datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}
model User {
id Int @id @default(autoincrement())
}`
if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil {
t.Fatalf("failed to write schema: %v", err)
}
reader := NewReader(&readers.ReaderOptions{
FilePath: schemaPath,
Prisma7: true,
})
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() failed: %v", err)
}
if db.SourceFormat != "prisma7" {
t.Fatalf("expected SourceFormat prisma7 from flag, got %q", db.SourceFormat)
}
}

View File

@@ -25,6 +25,9 @@ type ReaderOptions struct {
// ConnectionString is the database connection string (for DB readers) // ConnectionString is the database connection string (for DB readers)
ConnectionString string ConnectionString string
// Prisma7 enables Prisma 7-specific handling for Prisma schemas.
Prisma7 bool
// Additional options can be added here as needed // Additional options can be added here as needed
Metadata map[string]interface{} Metadata map[string]interface{}
} }

View File

@@ -5,9 +5,11 @@ import (
"fmt" "fmt"
"os" "os"
"regexp" "regexp"
"strconv"
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
) )
@@ -549,6 +551,41 @@ func (r *Reader) parseColumnOptions(decorator string, column *models.Column, tab
} }
} }
// Preserve explicit type modifiers from options where present.
// Example: @Column({ type: 'varchar', length: 255 }) -> varchar(255)
if column.Type != "" && !strings.Contains(column.Type, "(") {
lengthRegex := regexp.MustCompile(`length:\s*(\d+)`)
precisionRegex := regexp.MustCompile(`precision:\s*(\d+)`)
scaleRegex := regexp.MustCompile(`scale:\s*(\d+)`)
baseType := strings.ToLower(strings.TrimSpace(column.Type))
if pgsql.SupportsLength(baseType) {
if matches := lengthRegex.FindStringSubmatch(content); len(matches) == 2 {
if n, err := strconv.Atoi(matches[1]); err == nil && n > 0 {
column.Length = n
column.Type = fmt.Sprintf("%s(%d)", column.Type, n)
}
}
}
if pgsql.SupportsPrecision(baseType) {
if matches := precisionRegex.FindStringSubmatch(content); len(matches) == 2 {
if p, err := strconv.Atoi(matches[1]); err == nil && p > 0 {
column.Precision = p
if sm := scaleRegex.FindStringSubmatch(content); len(sm) == 2 {
if s, err := strconv.Atoi(sm[1]); err == nil && s >= 0 {
column.Scale = s
column.Type = fmt.Sprintf("%s(%d,%d)", column.Type, p, s)
}
} else {
column.Type = fmt.Sprintf("%s(%d)", column.Type, p)
}
}
}
}
}
if strings.Contains(content, "nullable: true") || strings.Contains(content, "nullable:true") { if strings.Contains(content, "nullable: true") || strings.Contains(content, "nullable:true") {
column.NotNull = false column.NotNull = false
} }

View File

@@ -0,0 +1,60 @@
package typeorm
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestParseColumnOptions_PreservesTypeModifiers(t *testing.T) {
reader := &Reader{}
table := models.InitTable("users", "public")
tests := []struct {
name string
decorator string
wantType string
wantLength int
wantPrecision int
wantScale int
}{
{
name: "varchar with length",
decorator: `@Column({ type: 'varchar', length: 255 })`,
wantType: "varchar(255)",
wantLength: 255,
},
{
name: "numeric with precision and scale",
decorator: `@Column({ type: 'numeric', precision: 10, scale: 2 })`,
wantType: "numeric(10,2)",
wantPrecision: 10,
wantScale: 2,
},
{
name: "custom type with explicit modifier is preserved",
decorator: `@Column({ type: 'vector(1536)' })`,
wantType: "vector(1536)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
col := models.InitColumn("sample", table.Name, table.Schema)
reader.parseColumnOptions(tt.decorator, col, table)
if col.Type != tt.wantType {
t.Fatalf("column type = %q, want %q", col.Type, tt.wantType)
}
if col.Length != tt.wantLength {
t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength)
}
if col.Precision != tt.wantPrecision {
t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision)
}
if col.Scale != tt.wantScale {
t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale)
}
})
}
}

View File

@@ -46,54 +46,67 @@ func main() {
### CLI Examples ### CLI Examples
```bash ```bash
# Generate Bun models from PostgreSQL database # Generate Bun models from a DBML schema (default: resolvespec types)
relspec --input pgsql \ relspec convert --from dbml --from-path schema.dbml \
--conn "postgres://localhost/mydb" \ --to bun --to-path models.go --package models
--output bun \
--out-file models.go \
--package models
# Convert GORM models to Bun # Use standard library database/sql nullable types instead of resolvespec
relspec --input gorm --in-file gorm_models.go --output bun --out-file bun_models.go relspec convert --from dbml --from-path schema.dbml \
--to bun --to-path models.go --package models \
--types stdlib
# Multi-file output # Explicitly select resolvespec types (same as omitting --types)
relspec --input json --in-file schema.json --output bun --out-file models/ relspec convert --from pgsql --from-conn "postgres://localhost/mydb" \
--to bun --to-path models.go --package models \
--types resolvespec
# Multi-file output (one file per table)
relspec convert --from json --from-path schema.json \
--to bun --to-path models/ --package models
``` ```
## Generated Code Example ## Generated Code Examples
### Default — resolvespec types (`--types resolvespec`)
```go ```go
package models package models
import ( import (
"time" resolvespec_common "github.com/bitechdev/ResolveSpec/pkg/spectypes"
"database/sql"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type User struct { type User struct {
bun.BaseModel `bun:"table:users,alias:u"` bun.BaseModel `bun:"table:users,alias:u"`
ID int64 `bun:"id,pk,autoincrement" json:"id"` ID int64 `bun:"id,type:uuid,pk," json:"id"`
Username string `bun:"username,notnull,unique" json:"username"` Username string `bun:"username,type:text,notnull," json:"username"`
Email string `bun:"email,notnull" json:"email"` Email resolvespec_common.SqlString `bun:"email,type:text,nullzero," json:"email"`
Bio sql.NullString `bun:"bio" json:"bio,omitempty"` Tags resolvespec_common.SqlStringArray `bun:"tags,type:text[],default:'{}',notnull," json:"tags"`
CreatedAt time.Time `bun:"created_at,notnull,default:now()" json:"created_at"` CreatedAt resolvespec_common.SqlTimeStamp `bun:"created_at,type:timestamptz,default:now(),notnull," json:"created_at"`
// Relationships
Posts []*Post `bun:"rel:has-many,join:id=user_id" json:"posts,omitempty"`
} }
```
type Post struct { ### Standard library — `--types stdlib`
bun.BaseModel `bun:"table:posts,alias:p"`
ID int64 `bun:"id,pk" json:"id"` ```go
UserID int64 `bun:"user_id,notnull" json:"user_id"` package models
Title string `bun:"title,notnull" json:"title"`
Content sql.NullString `bun:"content" json:"content,omitempty"`
// Belongs to import (
User *User `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"` "database/sql"
"time"
"github.com/uptrace/bun"
)
type User struct {
bun.BaseModel `bun:"table:users,alias:u"`
ID string `bun:"id,type:uuid,pk," json:"id"`
Username string `bun:"username,type:text,notnull," json:"username"`
Email sql.NullString `bun:"email,type:text,nullzero," json:"email"`
Tags []string `bun:"tags,type:text[],default:'{}',notnull," json:"tags"`
CreatedAt time.Time `bun:"created_at,type:timestamptz,default:now(),notnull," json:"created_at"`
} }
``` ```
@@ -111,19 +124,68 @@ type Post struct {
## Type Mapping ## Type Mapping
| SQL Type | Go Type | Nullable Type | The nullable type package is selected with `--types` (or `WriterOptions.NullableTypes`).
|----------|---------|---------------|
| bigint | int64 | sql.NullInt64 | | SQL Type | NOT NULL (both) | Nullable — resolvespec | Nullable — stdlib |
| integer | int | sql.NullInt32 | |---|---|---|---|
| varchar, text | string | sql.NullString | | `bigint` | `int64` | `SqlInt64` | `sql.NullInt64` |
| boolean | bool | sql.NullBool | | `integer` | `int32` | `SqlInt32` | `sql.NullInt32` |
| timestamp | time.Time | sql.NullTime | | `smallint` | `int16` | `SqlInt16` | `sql.NullInt16` |
| numeric | float64 | sql.NullFloat64 | | `text`, `varchar` | `string` | `SqlString` | `sql.NullString` |
| `boolean` | `bool` | `SqlBool` | `sql.NullBool` |
| `timestamp`, `timestamptz` | `time.Time`* | `SqlTimeStamp` | `sql.NullTime` |
| `numeric`, `decimal` | `float64` | `SqlFloat64` | `sql.NullFloat64` |
| `uuid` | `string` | `SqlUUID` | `sql.NullString` |
| `jsonb` | `string` | `SqlJSONB` | `sql.NullString` |
| `text[]` | `SqlStringArray` | `SqlStringArray` | `[]string` |
| `integer[]` | `SqlInt32Array` | `SqlInt32Array` | `[]int32` |
| `uuid[]` | `SqlUUIDArray` | `SqlUUIDArray` | `[]string` |
| `vector` | `SqlVector` | `SqlVector` | `[]float32` |
\* In resolvespec mode, only simple base types (such as integers and booleans) fall back to plain Go types when NOT NULL; timestamps are not treated as simple, so NOT NULL timestamps still use `SqlTimeStamp` rather than `time.Time`. In stdlib mode, NOT NULL timestamps use `time.Time`.
## Writer Options
### NullableTypes
Controls which Go package is used for nullable column types. Set via the `--types` CLI flag or `WriterOptions.NullableTypes`:
```go
// Use resolvespec types (default — omit NullableTypes or set to "resolvespec")
options := &writers.WriterOptions{
OutputPath: "models.go",
PackageName: "models",
NullableTypes: writers.NullableTypeResolveSpec,
}
// Use standard library database/sql types
options := &writers.WriterOptions{
OutputPath: "models.go",
PackageName: "models",
NullableTypes: writers.NullableTypeStdlib,
}
```
### Metadata Options
```go
options := &writers.WriterOptions{
OutputPath: "models.go",
PackageName: "models",
Metadata: map[string]any{
"multi_file": true, // Enable multi-file mode
"populate_refs": true, // Populate RefDatabase/RefSchema
"generate_get_id_str": true, // Generate GetIDStr() methods
},
}
```
## Notes ## Notes
- Model names are derived from table names (singularized, PascalCase) - Model names are derived from table names (singularized, PascalCase)
- Table aliases are auto-generated from table names - Table aliases are auto-generated from table names
- Nullable columns use `resolvespec_common.SqlString`, `resolvespec_common.SqlTimeStamp`, etc. by default; pass `--types stdlib` to use `sql.NullString`, `sql.NullTime`, etc. instead
- Array columns use `resolvespec_common.SqlStringArray`, `resolvespec_common.SqlInt32Array`, etc. by default; `--types stdlib` produces plain Go slices (`[]string`, `[]int32`, …)
- Multi-file mode: one file per table named `sql_{schema}_{table}.go` - Multi-file mode: one file per table named `sql_{schema}_{table}.go`
- Generated code is auto-formatted - Generated code is auto-formatted
- JSON tags are automatically added - JSON tags are automatically added

View File

@@ -26,7 +26,10 @@ type ModelData struct {
Fields []*FieldData Fields []*FieldData
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call) PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix Prefix string // 3-letter prefix
} }
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
model.IDColumnName = safeName model.IDColumnName = safeName
// Check if PK type is a SQL type (contains resolvespec_common or sql_types) // Check if PK type is a SQL type (contains resolvespec_common or sql_types)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types") model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
break break
} }
} }
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
return comment return comment
} }
func isStringLikePrimaryKeyType(goType string) bool {
switch goType {
case "string", "sql.NullString", "resolvespec_common.SqlString", "resolvespec_common.SqlUUID":
return true
default:
return false
}
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names // resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision // and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string { func resolveFieldNameCollision(fieldName string) string {

View File

@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
{{end}} {{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}} {{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value // GetID returns the primary key value
func (m {{.Name}}) GetID() int64 { func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}} {{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64() return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}} {{- else -}}
return int64(m.{{.PrimaryKeyField}}) return int64(m.{{.PrimaryKeyField}})
{{- end}} {{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string // GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string { func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}} {{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value // SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) { func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid) m.UpdateID(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}} {{if .PrimaryKeyIsSQL -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid)) {{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}} {{- else -}}
m.{{.PrimaryKeyField}} = int32(newid) m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}} {{- end}}
} }
{{end}} {{end}}
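
To make the branching above concrete, this is roughly what the template would emit for a model whose primary key maps to `resolvespec_common.SqlUUID`, so `PrimaryKeyIsSQL`, `PrimaryKeyIsStr`, and `PrimaryKeyIDType == "string"` all hold (the `User` and `ID` names are hypothetical, and all helper methods are assumed enabled):

```go
// GetID returns the primary key value
func (m User) GetID() string {
	return m.ID.String()
}

// GetIDStr returns the primary key as a string
func (m User) GetIDStr() string {
	return m.ID.String()
}

// SetID sets the primary key value
func (m User) SetID(newid string) {
	m.UpdateID(newid)
}

// UpdateID updates the primary key value
func (m *User) UpdateID(newid string) {
	m.ID.FromString(newid)
}
```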

View File

@@ -5,48 +5,55 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
// TypeMapper handles type conversions between SQL and Go types for Bun // TypeMapper handles type conversions between SQL and Go types for Bun
type TypeMapper struct { type TypeMapper struct {
// Package alias for sql_types import
sqlTypesAlias string sqlTypesAlias string
typeStyle string // writers.NullableTypeResolveSpec | writers.NullableTypeStdlib
} }
// NewTypeMapper creates a new TypeMapper with default settings // NewTypeMapper creates a new TypeMapper.
func NewTypeMapper() *TypeMapper { // typeStyle should be writers.NullableTypeResolveSpec or writers.NullableTypeStdlib;
// an empty string defaults to resolvespec.
func NewTypeMapper(typeStyle string) *TypeMapper {
if typeStyle == "" {
typeStyle = writers.NullableTypeResolveSpec
}
return &TypeMapper{ return &TypeMapper{
sqlTypesAlias: "resolvespec_common", sqlTypesAlias: "resolvespec_common",
typeStyle: typeStyle,
} }
} }
// SQLTypeToGoType converts a SQL type to its Go equivalent // SQLTypeToGoType converts a SQL type to its Go equivalent.
// Uses ResolveSpec common package types (all are nullable by default in Bun)
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string { func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
// Normalize SQL type (lowercase, remove length/precision) // Array types are handled separately for both styles.
if pgsql.IsArrayType(sqlType) {
return tm.arrayGoType(tm.extractBaseType(sqlType))
}
baseType := tm.extractBaseType(sqlType) baseType := tm.extractBaseType(sqlType)
// For Bun, we typically use resolvespec_common types for most fields if tm.typeStyle == writers.NullableTypeStdlib {
// unless they're explicitly NOT NULL and we want to avoid null handling if notNull {
return tm.rawGoType(baseType)
}
return tm.stdlibNullableGoType(baseType)
}
// resolvespec (default): use base Go types only for simple NOT NULL fields.
if notNull && tm.isSimpleType(baseType) { if notNull && tm.isSimpleType(baseType) {
return tm.baseGoType(baseType) return tm.baseGoType(baseType)
} }
// Use resolvespec_common types for nullable fields
return tm.bunGoType(baseType) return tm.bunGoType(baseType)
} }
// extractBaseType extracts the base type from a SQL type string // extractBaseType extracts the base type from a SQL type string
func (tm *TypeMapper) extractBaseType(sqlType string) string { func (tm *TypeMapper) extractBaseType(sqlType string) string {
sqlType = strings.ToLower(strings.TrimSpace(sqlType)) return pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
// Remove everything after '('
if idx := strings.Index(sqlType, "("); idx > 0 {
sqlType = sqlType[:idx]
}
return sqlType
} }
// isSimpleType checks if a type should use base Go type when NOT NULL // isSimpleType checks if a type should use base Go type when NOT NULL
@@ -160,6 +167,9 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
// Other // Other
"money": tm.sqlTypesAlias + ".SqlFloat64", "money": tm.sqlTypesAlias + ".SqlFloat64",
// pgvector
"vector": tm.sqlTypesAlias + ".SqlVector",
} }
if goType, ok := typeMap[sqlType]; ok { if goType, ok := typeMap[sqlType]; ok {
@@ -170,6 +180,123 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
return tm.sqlTypesAlias + ".SqlString" return tm.sqlTypesAlias + ".SqlString"
} }
// arrayGoType returns the Go type for a PostgreSQL array column.
// The baseElemType is the canonical base type (e.g. "text", "integer").
func (tm *TypeMapper) arrayGoType(baseElemType string) string {
if tm.typeStyle == writers.NullableTypeStdlib {
return tm.stdlibArrayGoType(baseElemType)
}
typeMap := map[string]string{
"text": tm.sqlTypesAlias + ".SqlStringArray", "varchar": tm.sqlTypesAlias + ".SqlStringArray",
"char": tm.sqlTypesAlias + ".SqlStringArray", "character": tm.sqlTypesAlias + ".SqlStringArray",
"citext": tm.sqlTypesAlias + ".SqlStringArray", "bpchar": tm.sqlTypesAlias + ".SqlStringArray",
"inet": tm.sqlTypesAlias + ".SqlStringArray", "cidr": tm.sqlTypesAlias + ".SqlStringArray",
"macaddr": tm.sqlTypesAlias + ".SqlStringArray",
"json": tm.sqlTypesAlias + ".SqlStringArray", "jsonb": tm.sqlTypesAlias + ".SqlStringArray",
"integer": tm.sqlTypesAlias + ".SqlInt32Array", "int": tm.sqlTypesAlias + ".SqlInt32Array",
"int4": tm.sqlTypesAlias + ".SqlInt32Array", "serial": tm.sqlTypesAlias + ".SqlInt32Array",
"smallint": tm.sqlTypesAlias + ".SqlInt16Array", "int2": tm.sqlTypesAlias + ".SqlInt16Array",
"smallserial": tm.sqlTypesAlias + ".SqlInt16Array",
"bigint": tm.sqlTypesAlias + ".SqlInt64Array", "int8": tm.sqlTypesAlias + ".SqlInt64Array",
"bigserial": tm.sqlTypesAlias + ".SqlInt64Array",
"real": tm.sqlTypesAlias + ".SqlFloat32Array", "float4": tm.sqlTypesAlias + ".SqlFloat32Array",
"double precision": tm.sqlTypesAlias + ".SqlFloat64Array", "float8": tm.sqlTypesAlias + ".SqlFloat64Array",
"numeric": tm.sqlTypesAlias + ".SqlFloat64Array", "decimal": tm.sqlTypesAlias + ".SqlFloat64Array",
"money": tm.sqlTypesAlias + ".SqlFloat64Array",
"boolean": tm.sqlTypesAlias + ".SqlBoolArray", "bool": tm.sqlTypesAlias + ".SqlBoolArray",
"uuid": tm.sqlTypesAlias + ".SqlUUIDArray",
}
if goType, ok := typeMap[baseElemType]; ok {
return goType
}
return tm.sqlTypesAlias + ".SqlStringArray"
}
// rawGoType returns the plain Go type for a NOT NULL column in stdlib mode.
func (tm *TypeMapper) rawGoType(sqlType string) string {
typeMap := map[string]string{
"integer": "int32", "int": "int32", "int4": "int32", "serial": "int32",
"smallint": "int16", "int2": "int16", "smallserial": "int16",
"bigint": "int64", "int8": "int64", "bigserial": "int64",
"boolean": "bool", "bool": "bool",
"real": "float32", "float4": "float32",
"double precision": "float64", "float8": "float64",
"numeric": "float64", "decimal": "float64", "money": "float64",
"text": "string", "varchar": "string", "char": "string",
"character": "string", "citext": "string", "bpchar": "string",
"inet": "string", "cidr": "string", "macaddr": "string",
"uuid": "string", "json": "string", "jsonb": "string",
"timestamp": "time.Time",
"timestamp without time zone": "time.Time",
"timestamp with time zone": "time.Time",
"timestamptz": "time.Time",
"date": "time.Time",
"time": "time.Time",
"time without time zone": "time.Time",
"time with time zone": "time.Time",
"timetz": "time.Time",
"bytea": "[]byte",
"vector": "[]float32",
}
if goType, ok := typeMap[sqlType]; ok {
return goType
}
return "string"
}
// stdlibNullableGoType returns the database/sql nullable type for a column.
func (tm *TypeMapper) stdlibNullableGoType(sqlType string) string {
typeMap := map[string]string{
"integer": "sql.NullInt32", "int": "sql.NullInt32", "int4": "sql.NullInt32", "serial": "sql.NullInt32",
"smallint": "sql.NullInt16", "int2": "sql.NullInt16", "smallserial": "sql.NullInt16",
"bigint": "sql.NullInt64", "int8": "sql.NullInt64", "bigserial": "sql.NullInt64",
"boolean": "sql.NullBool", "bool": "sql.NullBool",
"real": "sql.NullFloat64", "float4": "sql.NullFloat64",
"double precision": "sql.NullFloat64", "float8": "sql.NullFloat64",
"numeric": "sql.NullFloat64", "decimal": "sql.NullFloat64", "money": "sql.NullFloat64",
"text": "sql.NullString", "varchar": "sql.NullString", "char": "sql.NullString",
"character": "sql.NullString", "citext": "sql.NullString", "bpchar": "sql.NullString",
"inet": "sql.NullString", "cidr": "sql.NullString", "macaddr": "sql.NullString",
"uuid": "sql.NullString", "json": "sql.NullString", "jsonb": "sql.NullString",
"timestamp": "sql.NullTime",
"timestamp without time zone": "sql.NullTime",
"timestamp with time zone": "sql.NullTime",
"timestamptz": "sql.NullTime",
"date": "sql.NullTime",
"time": "sql.NullTime",
"time without time zone": "sql.NullTime",
"time with time zone": "sql.NullTime",
"timetz": "sql.NullTime",
"bytea": "[]byte",
"vector": "[]float32",
}
if goType, ok := typeMap[sqlType]; ok {
return goType
}
return "sql.NullString"
}
// stdlibArrayGoType returns a plain Go slice type for array columns in stdlib mode.
func (tm *TypeMapper) stdlibArrayGoType(baseElemType string) string {
typeMap := map[string]string{
"text": "[]string", "varchar": "[]string", "char": "[]string",
"character": "[]string", "citext": "[]string", "bpchar": "[]string",
"inet": "[]string", "cidr": "[]string", "macaddr": "[]string",
"uuid": "[]string", "json": "[]string", "jsonb": "[]string",
"integer": "[]int32", "int": "[]int32", "int4": "[]int32", "serial": "[]int32",
"smallint": "[]int16", "int2": "[]int16", "smallserial": "[]int16",
"bigint": "[]int64", "int8": "[]int64", "bigserial": "[]int64",
"real": "[]float32", "float4": "[]float32",
"double precision": "[]float64", "float8": "[]float64",
"numeric": "[]float64", "decimal": "[]float64", "money": "[]float64",
"boolean": "[]bool", "bool": "[]bool",
}
if goType, ok := typeMap[baseElemType]; ok {
return goType
}
return "[]string"
}
// BuildBunTag generates a complete Bun tag string for a column // BuildBunTag generates a complete Bun tag string for a column
// Bun format: bun:"column_name,type:type_name,pk,default:value" // Bun format: bun:"column_name,type:type_name,pk,default:value"
func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) string { func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) string {
@@ -184,9 +311,11 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
if column.Type != "" { if column.Type != "" {
// Sanitize type to remove backticks // Sanitize type to remove backticks
typeStr := writers.SanitizeStructTagValue(column.Type) typeStr := writers.SanitizeStructTagValue(column.Type)
if column.Length > 0 { isArray := pgsql.IsArrayType(typeStr)
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
if !hasExplicitTypeModifier && !isArray && column.Length > 0 {
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
} else if column.Precision > 0 { } else if !hasExplicitTypeModifier && !isArray && column.Precision > 0 {
if column.Scale > 0 { if column.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale) typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
} else { } else {
@@ -194,6 +323,9 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
} }
} }
parts = append(parts, fmt.Sprintf("type:%s", typeStr)) parts = append(parts, fmt.Sprintf("type:%s", typeStr))
if isArray && tm.typeStyle == writers.NullableTypeStdlib {
parts = append(parts, "array")
}
} }
// Primary key // Primary key
@@ -291,11 +423,20 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
return generateGetIDStr return generateGetIDStr
} }
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common) // GetSQLTypesImport returns the import path for the ResolveSpec spectypes package.
func (tm *TypeMapper) GetSQLTypesImport() string { func (tm *TypeMapper) GetSQLTypesImport() string {
return "github.com/bitechdev/ResolveSpec/pkg/spectypes" return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
} }
// GetNullableTypeImportLine returns the full Go import line for the nullable type
// package, ready to pass to AddImport: "database/sql" in stdlib mode, otherwise the aliased ResolveSpec spectypes import.
func (tm *TypeMapper) GetNullableTypeImportLine() string {
if tm.typeStyle == writers.NullableTypeStdlib {
return "\"database/sql\""
}
return fmt.Sprintf("%s \"%s\"", tm.sqlTypesAlias, tm.GetSQLTypesImport())
}
// GetBunImport returns the import path for Bun // GetBunImport returns the import path for Bun
func (tm *TypeMapper) GetBunImport() string { func (tm *TypeMapper) GetBunImport() string {
return "github.com/uptrace/bun" return "github.com/uptrace/bun"

View File

@@ -24,7 +24,7 @@ type Writer struct {
func NewWriter(options *writers.WriterOptions) *Writer { func NewWriter(options *writers.WriterOptions) *Writer {
w := &Writer{ w := &Writer{
options: options, options: options,
typeMapper: NewTypeMapper(), typeMapper: NewTypeMapper(options.NullableTypes),
config: LoadMethodConfigFromMetadata(options.Metadata), config: LoadMethodConfigFromMetadata(options.Metadata),
} }
@@ -80,8 +80,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
// Add bun import (always needed) // Add bun import (always needed)
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport())) templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
// Add resolvespec_common import (always needed for nullable types) // Add nullable types import (resolvespec or stdlib depending on options)
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport())) templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
// Collect all models // Collect all models
for _, schema := range db.Schemas { for _, schema := range db.Schemas {
@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -177,8 +177,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
// Add bun import // Add bun import
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport())) templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
// Add resolvespec_common import // Add nullable types import (resolvespec or stdlib depending on options)
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport())) templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
// Create model data // Create model data
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema) modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
} }
func (w *Writer) needsFmtImport(models []*ModelData) bool {
if w.config.GenerateGetIDStr {
for _, model := range models {
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
if w.config.GenerateUpdateID {
for _, model := range models {
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
return false
}
// findTable finds a table by schema and name in the database // findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas { for _, schema := range db.Schemas {

View File

@@ -556,7 +556,7 @@ func TestWriter_FieldNameCollision(t *testing.T) {
} }
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) { func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
mapper := NewTypeMapper() mapper := NewTypeMapper("")
tests := []struct { tests := []struct {
sqlType string sqlType string
@@ -574,6 +574,10 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
{"boolean", false, "resolvespec_common.SqlBool"}, {"boolean", false, "resolvespec_common.SqlBool"},
{"uuid", false, "resolvespec_common.SqlUUID"}, {"uuid", false, "resolvespec_common.SqlUUID"},
{"jsonb", false, "resolvespec_common.SqlJSONB"}, {"jsonb", false, "resolvespec_common.SqlJSONB"},
{"text[]", true, "resolvespec_common.SqlStringArray"},
{"text[]", false, "resolvespec_common.SqlStringArray"},
{"integer[]", true, "resolvespec_common.SqlInt32Array"},
{"bigint[]", false, "resolvespec_common.SqlInt64Array"},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -586,8 +590,118 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
} }
} }
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
tests := []struct {
name string
pkType string
expectedPK string
expectedLine string
forbidInt32 bool
}{
{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table := models.InitTable("test_table", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: tt.pkType,
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
if !strings.Contains(generated, tt.expectedLine) {
t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
}
if !strings.Contains(generated, "ID "+tt.expectedPK) {
t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
}
if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
}
if !strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
}
})
}
}
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
table := models.InitTable("accounts", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: "uuid",
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
expectations := []string{
"resolvespec_common.SqlUUID",
"func (m ModelPublicAccounts) GetID() string",
"return m.ID.String()",
"func (m ModelPublicAccounts) GetIDStr() string",
"func (m ModelPublicAccounts) SetID(newid string)",
"func (m *ModelPublicAccounts) UpdateID(newid string)",
"m.ID.FromString(newid)",
}
for _, expected := range expectations {
if !strings.Contains(generated, expected) {
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
}
}
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
}
}
func TestTypeMapper_BuildBunTag(t *testing.T) { func TestTypeMapper_BuildBunTag(t *testing.T) {
mapper := NewTypeMapper() mapper := NewTypeMapper("")
tests := []struct { tests := []struct {
name string name string
@@ -685,6 +799,24 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
}, },
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"}, want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
}, },
{
name: "text array type",
column: &models.Column{
Name: "tags",
Type: "text[]",
NotNull: false,
},
want: []string{"tags,", "type:text[],"},
},
{
name: "integer array type",
column: &models.Column{
Name: "scores",
Type: "integer[]",
NotNull: true,
},
want: []string{"scores,", "type:integer[],"},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -695,6 +827,50 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
t.Errorf("BuildBunTag() = %q, missing %q", result, part) t.Errorf("BuildBunTag() = %q, missing %q", result, part)
} }
} }
// resolvespec mode must NOT add "array" — SqlXxxArray uses sql.Scanner
if strings.Contains(result, ",array,") || strings.HasSuffix(result, ",array,") {
t.Errorf("BuildBunTag() = %q, must not contain 'array' in resolvespec mode", result)
}
}) })
} }
} }
func TestTypeMapper_BuildBunTag_StdlibArrayHasArrayTag(t *testing.T) {
mapper := NewTypeMapper(writers.NullableTypeStdlib)
cases := []struct {
name string
column *models.Column
}{
{name: "text array", column: &models.Column{Name: "tags", Type: "text[]"}},
{name: "integer array", column: &models.Column{Name: "scores", Type: "integer[]", NotNull: true}},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
result := mapper.BuildBunTag(tt.column, nil)
if !strings.Contains(result, "array") {
t.Errorf("BuildBunTag() = %q, expected 'array' in stdlib mode", result)
}
})
}
}
func TestTypeMapper_BuildBunTag_PreservesExplicitTypeModifiers(t *testing.T) {
mapper := NewTypeMapper("")
col := &models.Column{
Name: "embedding",
Type: "vector(1536)",
Length: 1536,
Precision: 0,
Scale: 0,
}
tag := mapper.BuildBunTag(col, nil)
if !strings.Contains(tag, "type:vector(1536),") {
t.Fatalf("expected explicit modifier to be preserved, got %q", tag)
}
if strings.Contains(tag, ")(") {
t.Fatalf("type modifier appears duplicated in %q", tag)
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"os" "os"
"sort"
"strings" "strings"
"github.com/google/uuid" "github.com/google/uuid"
@@ -155,8 +156,15 @@ func (w *Writer) mapTableFields(table *models.Table) models.DCTXTable {
}, },
} }
columnNames := make([]string, 0, len(table.Columns))
for name := range table.Columns {
columnNames = append(columnNames, name)
}
sort.Strings(columnNames)
i := 0 i := 0
for _, column := range table.Columns { for _, colName := range columnNames {
column := table.Columns[colName]
dctxTable.Fields[i] = w.mapField(column) dctxTable.Fields[i] = w.mapField(column)
i++ i++
} }
@@ -165,12 +173,27 @@ func (w *Writer) mapTableFields(table *models.Table) models.DCTXTable {
} }
func (w *Writer) mapTableKeys(table *models.Table) []models.DCTXKey { func (w *Writer) mapTableKeys(table *models.Table) []models.DCTXKey {
keys := make([]models.DCTXKey, len(table.Indexes)) indexes := make([]*models.Index, 0, len(table.Indexes))
i := 0
for _, index := range table.Indexes { for _, index := range table.Indexes {
keys[i] = w.mapKey(index, table) indexes = append(indexes, index)
i++
} }
// Stable ordering for deterministic output and test reproducibility:
// primary keys first, then lexicographic by index name.
sort.Slice(indexes, func(i, j int) bool {
iPrimary := strings.HasSuffix(indexes[i].Name, "_pkey")
jPrimary := strings.HasSuffix(indexes[j].Name, "_pkey")
if iPrimary != jPrimary {
return iPrimary
}
return indexes[i].Name < indexes[j].Name
})
keys := make([]models.DCTXKey, len(indexes))
for i, index := range indexes {
keys[i] = w.mapKey(index, table)
}
return keys return keys
} }

View File

@@ -5,6 +5,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
) )
// TypeMapper handles SQL to Drizzle type conversions // TypeMapper handles SQL to Drizzle type conversions
@@ -18,7 +19,7 @@ func NewTypeMapper() *TypeMapper {
// SQLTypeToDrizzle converts SQL types to Drizzle column type functions // SQLTypeToDrizzle converts SQL types to Drizzle column type functions
// Returns the Drizzle column constructor (e.g., "integer", "varchar", "text") // Returns the Drizzle column constructor (e.g., "integer", "varchar", "text")
func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string { func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string {
sqlTypeLower := strings.ToLower(sqlType) sqlTypeLower := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
// PostgreSQL type mapping to Drizzle // PostgreSQL type mapping to Drizzle
typeMap := map[string]string{ typeMap := map[string]string{
@@ -87,13 +88,6 @@ func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string {
return drizzleType return drizzleType
} }
// Check for partial matches (e.g., "varchar(255)" -> "varchar")
for sqlPattern, drizzleType := range typeMap {
if strings.HasPrefix(sqlTypeLower, sqlPattern) {
return drizzleType
}
}
// Default to text for unknown types // Default to text for unknown types
return "text" return "text"
} }

View File

@@ -48,22 +48,23 @@ func main() {
 ### CLI Examples
 
 ```bash
-# Generate GORM models from PostgreSQL database (single file)
-relspec --input pgsql \
-  --conn "postgres://localhost/mydb" \
-  --output gorm \
-  --out-file models.go \
-  --package models
-
-# Generate GORM models with multi-file output (one file per table)
-relspec --input json \
-  --in-file schema.json \
-  --output gorm \
-  --out-file models/ \
-  --package models
-
-# Convert DBML to GORM models
-relspec --input dbml --in-file schema.dbml --output gorm --out-file models.go
+# Generate GORM models from a DBML schema (default: resolvespec types)
+relspec convert --from dbml --from-path schema.dbml \
+  --to gorm --to-path models.go --package models
+
+# Use standard library database/sql nullable types instead of resolvespec
+relspec convert --from dbml --from-path schema.dbml \
+  --to gorm --to-path models.go --package models \
+  --types stdlib
+
+# Explicitly select resolvespec types (same as omitting --types)
+relspec convert --from pgsql --from-conn "postgres://localhost/mydb" \
+  --to gorm --to-path models.go --package models \
+  --types resolvespec
+
+# Multi-file output (one file per table)
+relspec convert --from json --from-path schema.json \
+  --to gorm --to-path models/ --package models
 ```
 
 ## Output Modes
@@ -86,56 +87,84 @@ relspec --input pgsql --conn "..." --output gorm --out-file models/
 Files are named: `sql_{schema}_{table}.go`
 
-## Generated Code Example
-
-```go
-package models
-
-import (
-	"time"
-
-	sql_types "git.warky.dev/wdevs/sql_types"
-)
-
-type ModelUser struct {
-	ID        int64     `gorm:"column:id;type:bigint;primaryKey;autoIncrement" json:"id"`
-	Username  string    `gorm:"column:username;type:varchar(50);not null;uniqueIndex" json:"username"`
-	Email     string    `gorm:"column:email;type:varchar(100);not null" json:"email"`
-	CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:now()" json:"created_at"`
-
-	// Relationships
-	Pos []*ModelPost `gorm:"foreignKey:UserID;references:ID;constraint:OnDelete:CASCADE" json:"pos,omitempty"`
-}
-
-func (ModelUser) TableName() string {
-	return "public.users"
-}
-
-type ModelPost struct {
-	ID      int64               `gorm:"column:id;type:bigint;primaryKey" json:"id"`
-	UserID  int64               `gorm:"column:user_id;type:bigint;not null" json:"user_id"`
-	Title   string              `gorm:"column:title;type:varchar(200);not null" json:"title"`
-	Content sql_types.SqlString `gorm:"column:content;type:text" json:"content,omitempty"`
-
-	// Belongs to
-	Use *ModelUser `gorm:"foreignKey:UserID;references:ID" json:"use,omitempty"`
-}
-
-func (ModelPost) TableName() string {
-	return "public.posts"
-}
-```
+## Generated Code Examples
+
+### Default — resolvespec types (`--types resolvespec`)
+
+```go
+package models
+
+import (
+	sql_types "github.com/bitechdev/ResolveSpec/pkg/spectypes"
+)
+
+type ModelUser struct {
+	ID        string                   `gorm:"column:id;type:uuid;primaryKey" json:"id"`
+	Username  string                   `gorm:"column:username;type:text;not null" json:"username"`
+	Email     sql_types.SqlString      `gorm:"column:email;type:text" json:"email,omitempty"`
+	Tags      sql_types.SqlStringArray `gorm:"column:tags;type:text[];not null;default:'{}'" json:"tags"`
+	CreatedAt sql_types.SqlTimeStamp   `gorm:"column:created_at;type:timestamptz;not null;default:now()" json:"created_at"`
+}
+
+func (ModelUser) TableName() string {
+	return "public.users"
+}
+```
+
+### Standard library — `--types stdlib`
+
+```go
+package models
+
+import (
+	"database/sql"
+	"time"
+)
+
+type ModelUser struct {
+	ID        string         `gorm:"column:id;type:uuid;primaryKey" json:"id"`
+	Username  string         `gorm:"column:username;type:text;not null" json:"username"`
+	Email     sql.NullString `gorm:"column:email;type:text" json:"email,omitempty"`
+	Tags      []string       `gorm:"column:tags;type:text[];not null;default:'{}'" json:"tags"`
+	CreatedAt time.Time      `gorm:"column:created_at;type:timestamptz;not null;default:now()" json:"created_at"`
+}
+
+func (ModelUser) TableName() string {
+	return "public.users"
+}
+```
 ## Writer Options
 
+### NullableTypes
+
+Controls which Go package is used for nullable column types. Set via the `--types` CLI flag or `WriterOptions.NullableTypes`:
+
+```go
+// Use resolvespec types (default — omit NullableTypes or set to "resolvespec")
+options := &writers.WriterOptions{
+	OutputPath:    "models.go",
+	PackageName:   "models",
+	NullableTypes: writers.NullableTypeResolveSpec,
+}
+
+// Use standard library database/sql types
+options := &writers.WriterOptions{
+	OutputPath:    "models.go",
+	PackageName:   "models",
+	NullableTypes: writers.NullableTypeStdlib,
+}
+```
+
 ### Metadata Options
 
-Configure the writer behavior using metadata in `WriterOptions`:
+Configure additional writer behavior using metadata in `WriterOptions`:
 
 ```go
 options := &writers.WriterOptions{
 	OutputPath:  "models.go",
 	PackageName: "models",
-	Metadata: map[string]interface{}{
+	Metadata: map[string]any{
 		"multi_file": true,          // Enable multi-file mode
 		"populate_refs": true,       // Populate RefDatabase/RefSchema
 		"generate_get_id_str": true, // Generate GetIDStr() methods
@@ -145,18 +174,23 @@ options := &writers.WriterOptions{
 ## Type Mapping
 
-| SQL Type | Go Type | Notes |
-|----------|---------|-------|
-| bigint, int8 | int64 | - |
-| integer, int, int4 | int | - |
-| smallint, int2 | int16 | - |
-| varchar, text | string | Not nullable |
-| varchar, text (nullable) | sql_types.SqlString | Nullable |
-| boolean, bool | bool | - |
-| timestamp, timestamptz | time.Time | - |
-| numeric, decimal | float64 | - |
-| uuid | string | - |
-| json, jsonb | string | - |
+The nullable type package is selected with `--types` (or `WriterOptions.NullableTypes`).
+
+| SQL Type | NOT NULL — both | Nullable — resolvespec | Nullable — stdlib |
+|---|---|---|---|
+| `bigint` | `int64` | `SqlInt64` | `sql.NullInt64` |
+| `integer` | `int32` | `SqlInt32` | `sql.NullInt32` |
+| `smallint` | `int16` | `SqlInt16` | `sql.NullInt16` |
+| `text`, `varchar` | `string` | `SqlString` | `sql.NullString` |
+| `boolean` | `bool` | `SqlBool` | `sql.NullBool` |
+| `timestamp`, `timestamptz` | `time.Time` | `SqlTimeStamp` | `sql.NullTime` |
+| `numeric`, `decimal` | `float64` | `SqlFloat64` | `sql.NullFloat64` |
+| `uuid` | `string` | `SqlUUID` | `sql.NullString` |
+| `jsonb` | `string` | `SqlString` | `sql.NullString` |
+| `text[]` | `SqlStringArray` | `SqlStringArray` | `[]string` |
+| `integer[]` | `SqlInt32Array` | `SqlInt32Array` | `[]int32` |
+| `uuid[]` | `SqlUUIDArray` | `SqlUUIDArray` | `[]string` |
+| `vector` | `SqlVector` | `SqlVector` | `[]float32` |
 
 ## Relationship Generation
@@ -170,7 +204,8 @@ The writer automatically generates relationship fields:
 ## Notes
 
 - Model names are prefixed with "Model" (e.g., `ModelUser`)
-- Nullable columns use `sql_types.SqlString`, `sql_types.SqlInt64`, etc.
+- Nullable columns use `sql_types.SqlString`, `sql_types.SqlInt64`, etc. by default; pass `--types stdlib` to use `sql.NullString`, `sql.NullInt64`, etc. instead
+- Array columns use `sql_types.SqlStringArray`, `sql_types.SqlInt32Array`, etc. by default; `--types stdlib` produces plain Go slices (`[]string`, `[]int32`, …)
 - Generated code is auto-formatted with `go fmt`
 - JSON tags are automatically added
 - Supports schema-qualified table names in `TableName()` method
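
For completeness, here is a minimal sketch of driving the GORM writer from Go instead of the CLI, using the stdlib nullable types. It mirrors the test setup in this changeset; the `pkg/writers/gorm` import path is an assumption, and in a real pipeline the table would come from one of the readers rather than being built by hand.

```go
package main

import (
	"log"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
	gormwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm" // assumed import path
)

func main() {
	// Minimal in-memory table; normally produced by a reader (pgsql, dbml, json, ...).
	table := models.InitTable("users", "public")
	table.Columns["id"] = &models.Column{Name: "id", Type: "uuid", NotNull: true, IsPrimaryKey: true}
	table.Columns["email"] = &models.Column{Name: "email", Type: "text"} // nullable -> sql.NullString in stdlib mode

	opts := &writers.WriterOptions{
		PackageName:   "models",
		OutputPath:    "models.go",
		NullableTypes: writers.NullableTypeStdlib, // or writers.NullableTypeResolveSpec (the default)
	}

	if err := gormwriter.NewWriter(opts).WriteTable(table); err != nil {
		log.Fatalf("write models: %v", err)
	}
}
```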

View File

@@ -2,6 +2,7 @@ package gorm
import ( import (
"sort" "sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
@@ -26,6 +27,9 @@ type ModelData struct {
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field PrimaryKeyType string // Go type of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix Prefix string // 3-letter prefix
} }
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
// Sanitize column name to remove backticks // Sanitize column name to remove backticks
safeName := writers.SanitizeStructTagValue(col.Name) safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName) model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
model.IDColumnName = safeName model.IDColumnName = safeName
break break
} }
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
return comment return comment
} }
func isStringLikePrimaryKeyType(goType string) bool {
switch goType {
case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID":
return true
default:
return false
}
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names // resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision // and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string { func resolveFieldNameCollision(fieldName string) string {

View File

@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
{{end}} {{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}} {{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value // GetID returns the primary key value
func (m {{.Name}}) GetID() int64 { func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return int64(m.{{.PrimaryKeyField}}) return int64(m.{{.PrimaryKeyField}})
{{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}} {{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string // GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string { func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}}) return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}} {{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value // SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) { func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid) m.UpdateID(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid) m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}}
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}} {{if and .Config.GenerateGetIDName .IDColumnName}}

View File

@@ -5,48 +5,56 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
// TypeMapper handles type conversions between SQL and Go types // TypeMapper handles type conversions between SQL and Go types
type TypeMapper struct { type TypeMapper struct {
// Package alias for sql_types import
sqlTypesAlias string sqlTypesAlias string
typeStyle string // writers.NullableTypeResolveSpec | writers.NullableTypeStdlib
} }
// NewTypeMapper creates a new TypeMapper with default settings // NewTypeMapper creates a new TypeMapper.
func NewTypeMapper() *TypeMapper { // typeStyle should be writers.NullableTypeResolveSpec or writers.NullableTypeStdlib;
// an empty string defaults to resolvespec.
func NewTypeMapper(typeStyle string) *TypeMapper {
if typeStyle == "" {
typeStyle = writers.NullableTypeResolveSpec
}
return &TypeMapper{ return &TypeMapper{
sqlTypesAlias: "sql_types", sqlTypesAlias: "sql_types",
typeStyle: typeStyle,
} }
} }
// SQLTypeToGoType converts a SQL type to its Go equivalent // SQLTypeToGoType converts a SQL type to its Go equivalent.
// Handles nullable types using ResolveSpec sql_types package
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string { func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
// Normalize SQL type (lowercase, remove length/precision) // Array types are handled separately for both styles.
if pgsql.IsArrayType(sqlType) {
return tm.arrayGoType(tm.extractBaseType(sqlType))
}
baseType := tm.extractBaseType(sqlType) baseType := tm.extractBaseType(sqlType)
// If not null, use base Go types if tm.typeStyle == writers.NullableTypeStdlib {
if notNull {
return tm.rawGoType(baseType)
}
return tm.stdlibNullableGoType(baseType)
}
// resolvespec (default)
if notNull { if notNull {
return tm.baseGoType(baseType) return tm.baseGoType(baseType)
} }
// For nullable fields, use sql_types
return tm.nullableGoType(baseType) return tm.nullableGoType(baseType)
} }
// extractBaseType extracts the base type from a SQL type string // extractBaseType extracts the base type from a SQL type string
// Examples: varchar(100) → varchar, numeric(10,2) → numeric // Examples: varchar(100) → varchar, numeric(10,2) → numeric
func (tm *TypeMapper) extractBaseType(sqlType string) string { func (tm *TypeMapper) extractBaseType(sqlType string) string {
sqlType = strings.ToLower(strings.TrimSpace(sqlType)) return pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
// Remove everything after '('
if idx := strings.Index(sqlType, "("); idx > 0 {
sqlType = sqlType[:idx]
}
return sqlType
} }
// baseGoType returns the base Go type for a SQL type (not null) // baseGoType returns the base Go type for a SQL type (not null)
@@ -112,6 +120,9 @@ func (tm *TypeMapper) baseGoType(sqlType string) string {
// Other // Other
"money": "float64", "money": "float64",
// pgvector — always uses SqlVector even when NOT NULL
"vector": tm.sqlTypesAlias + ".SqlVector",
} }
if goType, ok := typeMap[sqlType]; ok { if goType, ok := typeMap[sqlType]; ok {
@@ -185,6 +196,9 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
// Other // Other
"money": tm.sqlTypesAlias + ".SqlFloat64", "money": tm.sqlTypesAlias + ".SqlFloat64",
// pgvector
"vector": tm.sqlTypesAlias + ".SqlVector",
} }
if goType, ok := typeMap[sqlType]; ok { if goType, ok := typeMap[sqlType]; ok {
@@ -195,6 +209,123 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
return tm.sqlTypesAlias + ".SqlString" return tm.sqlTypesAlias + ".SqlString"
} }
// arrayGoType returns the Go type for a PostgreSQL array column.
// The baseElemType is the canonical base type (e.g. "text", "integer").
func (tm *TypeMapper) arrayGoType(baseElemType string) string {
if tm.typeStyle == writers.NullableTypeStdlib {
return tm.stdlibArrayGoType(baseElemType)
}
typeMap := map[string]string{
"text": tm.sqlTypesAlias + ".SqlStringArray", "varchar": tm.sqlTypesAlias + ".SqlStringArray",
"char": tm.sqlTypesAlias + ".SqlStringArray", "character": tm.sqlTypesAlias + ".SqlStringArray",
"citext": tm.sqlTypesAlias + ".SqlStringArray", "bpchar": tm.sqlTypesAlias + ".SqlStringArray",
"inet": tm.sqlTypesAlias + ".SqlStringArray", "cidr": tm.sqlTypesAlias + ".SqlStringArray",
"macaddr": tm.sqlTypesAlias + ".SqlStringArray",
"json": tm.sqlTypesAlias + ".SqlStringArray", "jsonb": tm.sqlTypesAlias + ".SqlStringArray",
"integer": tm.sqlTypesAlias + ".SqlInt32Array", "int": tm.sqlTypesAlias + ".SqlInt32Array",
"int4": tm.sqlTypesAlias + ".SqlInt32Array", "serial": tm.sqlTypesAlias + ".SqlInt32Array",
"smallint": tm.sqlTypesAlias + ".SqlInt16Array", "int2": tm.sqlTypesAlias + ".SqlInt16Array",
"smallserial": tm.sqlTypesAlias + ".SqlInt16Array",
"bigint": tm.sqlTypesAlias + ".SqlInt64Array", "int8": tm.sqlTypesAlias + ".SqlInt64Array",
"bigserial": tm.sqlTypesAlias + ".SqlInt64Array",
"real": tm.sqlTypesAlias + ".SqlFloat32Array", "float4": tm.sqlTypesAlias + ".SqlFloat32Array",
"double precision": tm.sqlTypesAlias + ".SqlFloat64Array", "float8": tm.sqlTypesAlias + ".SqlFloat64Array",
"numeric": tm.sqlTypesAlias + ".SqlFloat64Array", "decimal": tm.sqlTypesAlias + ".SqlFloat64Array",
"money": tm.sqlTypesAlias + ".SqlFloat64Array",
"boolean": tm.sqlTypesAlias + ".SqlBoolArray", "bool": tm.sqlTypesAlias + ".SqlBoolArray",
"uuid": tm.sqlTypesAlias + ".SqlUUIDArray",
}
if goType, ok := typeMap[baseElemType]; ok {
return goType
}
return tm.sqlTypesAlias + ".SqlStringArray"
}
// rawGoType returns the plain Go type for a NOT NULL column in stdlib mode.
func (tm *TypeMapper) rawGoType(sqlType string) string {
typeMap := map[string]string{
"integer": "int32", "int": "int32", "int4": "int32", "serial": "int32",
"smallint": "int16", "int2": "int16", "smallserial": "int16",
"bigint": "int64", "int8": "int64", "bigserial": "int64",
"boolean": "bool", "bool": "bool",
"real": "float32", "float4": "float32",
"double precision": "float64", "float8": "float64",
"numeric": "float64", "decimal": "float64", "money": "float64",
"text": "string", "varchar": "string", "char": "string",
"character": "string", "citext": "string", "bpchar": "string",
"inet": "string", "cidr": "string", "macaddr": "string",
"uuid": "string", "json": "string", "jsonb": "string",
"timestamp": "time.Time",
"timestamp without time zone": "time.Time",
"timestamp with time zone": "time.Time",
"timestamptz": "time.Time",
"date": "time.Time",
"time": "time.Time",
"time without time zone": "time.Time",
"time with time zone": "time.Time",
"timetz": "time.Time",
"bytea": "[]byte",
"vector": "[]float32",
}
if goType, ok := typeMap[sqlType]; ok {
return goType
}
return "string"
}
// stdlibNullableGoType returns the database/sql nullable type for a column.
func (tm *TypeMapper) stdlibNullableGoType(sqlType string) string {
typeMap := map[string]string{
"integer": "sql.NullInt32", "int": "sql.NullInt32", "int4": "sql.NullInt32", "serial": "sql.NullInt32",
"smallint": "sql.NullInt16", "int2": "sql.NullInt16", "smallserial": "sql.NullInt16",
"bigint": "sql.NullInt64", "int8": "sql.NullInt64", "bigserial": "sql.NullInt64",
"boolean": "sql.NullBool", "bool": "sql.NullBool",
"real": "sql.NullFloat64", "float4": "sql.NullFloat64",
"double precision": "sql.NullFloat64", "float8": "sql.NullFloat64",
"numeric": "sql.NullFloat64", "decimal": "sql.NullFloat64", "money": "sql.NullFloat64",
"text": "sql.NullString", "varchar": "sql.NullString", "char": "sql.NullString",
"character": "sql.NullString", "citext": "sql.NullString", "bpchar": "sql.NullString",
"inet": "sql.NullString", "cidr": "sql.NullString", "macaddr": "sql.NullString",
"uuid": "sql.NullString", "json": "sql.NullString", "jsonb": "sql.NullString",
"timestamp": "sql.NullTime",
"timestamp without time zone": "sql.NullTime",
"timestamp with time zone": "sql.NullTime",
"timestamptz": "sql.NullTime",
"date": "sql.NullTime",
"time": "sql.NullTime",
"time without time zone": "sql.NullTime",
"time with time zone": "sql.NullTime",
"timetz": "sql.NullTime",
"bytea": "[]byte",
"vector": "[]float32",
}
if goType, ok := typeMap[sqlType]; ok {
return goType
}
return "sql.NullString"
}
// stdlibArrayGoType returns a plain Go slice type for array columns in stdlib mode.
func (tm *TypeMapper) stdlibArrayGoType(baseElemType string) string {
typeMap := map[string]string{
"text": "[]string", "varchar": "[]string", "char": "[]string",
"character": "[]string", "citext": "[]string", "bpchar": "[]string",
"inet": "[]string", "cidr": "[]string", "macaddr": "[]string",
"uuid": "[]string", "json": "[]string", "jsonb": "[]string",
"integer": "[]int32", "int": "[]int32", "int4": "[]int32", "serial": "[]int32",
"smallint": "[]int16", "int2": "[]int16", "smallserial": "[]int16",
"bigint": "[]int64", "int8": "[]int64", "bigserial": "[]int64",
"real": "[]float32", "float4": "[]float32",
"double precision": "[]float64", "float8": "[]float64",
"numeric": "[]float64", "decimal": "[]float64", "money": "[]float64",
"boolean": "[]bool", "bool": "[]bool",
}
if goType, ok := typeMap[baseElemType]; ok {
return goType
}
return "[]string"
}
// BuildGormTag generates a complete GORM tag string for a column // BuildGormTag generates a complete GORM tag string for a column
func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) string { func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) string {
var parts []string var parts []string
@@ -209,9 +340,10 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
// Include length, precision, scale if present // Include length, precision, scale if present
// Sanitize type to remove backticks // Sanitize type to remove backticks
typeStr := writers.SanitizeStructTagValue(column.Type) typeStr := writers.SanitizeStructTagValue(column.Type)
if column.Length > 0 { hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
if !hasExplicitTypeModifier && column.Length > 0 {
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
} else if column.Precision > 0 { } else if !hasExplicitTypeModifier && column.Precision > 0 {
if column.Scale > 0 { if column.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale) typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
} else { } else {
@@ -335,7 +467,16 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
return generateGetIDStr return generateGetIDStr
} }
// GetSQLTypesImport returns the import path for sql_types // GetSQLTypesImport returns the import path for the ResolveSpec spectypes package.
func (tm *TypeMapper) GetSQLTypesImport() string { func (tm *TypeMapper) GetSQLTypesImport() string {
return "github.com/bitechdev/ResolveSpec/pkg/spectypes" return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
} }
// GetNullableTypeImportLine returns the full Go import line for the nullable type
// package, ready to pass to AddImport: "database/sql" in stdlib mode, otherwise the aliased ResolveSpec spectypes import.
func (tm *TypeMapper) GetNullableTypeImportLine() string {
if tm.typeStyle == writers.NullableTypeStdlib {
return "\"database/sql\""
}
return fmt.Sprintf("%s \"%s\"", tm.sqlTypesAlias, tm.GetSQLTypesImport())
}

View File

@@ -24,7 +24,7 @@ type Writer struct {
func NewWriter(options *writers.WriterOptions) *Writer { func NewWriter(options *writers.WriterOptions) *Writer {
w := &Writer{ w := &Writer{
options: options, options: options,
typeMapper: NewTypeMapper(), typeMapper: NewTypeMapper(options.NullableTypes),
config: LoadMethodConfigFromMetadata(options.Metadata), config: LoadMethodConfigFromMetadata(options.Metadata),
} }
@@ -77,8 +77,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
packageName := w.getPackageName() packageName := w.getPackageName()
templateData := NewTemplateData(packageName, w.config) templateData := NewTemplateData(packageName, w.config)
// Add sql_types import (always needed for nullable types) // Add nullable types import (resolvespec or stdlib depending on options)
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport())) templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
// Collect all models // Collect all models
for _, schema := range db.Schemas { for _, schema := range db.Schemas {
@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -171,8 +171,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
// Create template data for this single table // Create template data for this single table
templateData := NewTemplateData(packageName, w.config) templateData := NewTemplateData(packageName, w.config)
// Add sql_types import // Add nullable types import (resolvespec or stdlib depending on options)
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport())) templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
// Create model data // Create model data
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema) modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
} }
// Add fmt import if GetIDStr is enabled // Add fmt import when generated helper methods need string formatting.
if w.config.GenerateGetIDStr { if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"") templateData.AddImport("\"fmt\"")
} }
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
} }
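// needsFmtImport reports whether any generated helper method will call fmt.Sprintf:
// GetIDStr formats plain numeric primary keys, and UpdateID formats the int64 id
// before calling FromString on SQL wrapper (non-string) primary keys.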
func (w *Writer) needsFmtImport(models []*ModelData) bool {
if w.config.GenerateGetIDStr {
for _, model := range models {
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
if w.config.GenerateUpdateID {
for _, model := range models {
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
return false
}
// findTable finds a table by schema and name in the database // findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas { for _, schema := range db.Schemas {

View File

@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
} }
} }
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
table := models.InitTable("accounts", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: "uuid",
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
expectations := []string{
"ID string",
"func (m ModelPublicAccounts) GetID() string",
"return m.ID",
"func (m ModelPublicAccounts) GetIDStr() string",
"func (m ModelPublicAccounts) SetID(newid string)",
"func (m *ModelPublicAccounts) UpdateID(newid string)",
"m.ID = newid",
}
for _, expected := range expectations {
if !strings.Contains(generated, expected) {
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
}
}
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
}
}
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) { func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
@@ -643,7 +692,7 @@ func TestNameConverter_Pluralize(t *testing.T) {
} }
func TestTypeMapper_SQLTypeToGoType(t *testing.T) { func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
mapper := NewTypeMapper() mapper := NewTypeMapper("")
tests := []struct { tests := []struct {
sqlType string sqlType string
@@ -658,6 +707,10 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
{"timestamp", false, "sql_types.SqlTimeStamp"}, {"timestamp", false, "sql_types.SqlTimeStamp"},
{"boolean", true, "bool"}, {"boolean", true, "bool"},
{"boolean", false, "sql_types.SqlBool"}, {"boolean", false, "sql_types.SqlBool"},
{"text[]", true, "sql_types.SqlStringArray"},
{"text[]", false, "sql_types.SqlStringArray"},
{"integer[]", true, "sql_types.SqlInt32Array"},
{"bigint[]", false, "sql_types.SqlInt64Array"},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -669,3 +722,38 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
}) })
} }
} }
func TestTypeMapper_BuildGormTag_ArrayType(t *testing.T) {
mapper := NewTypeMapper("")
col := &models.Column{
Name: "tags",
Type: "text[]",
NotNull: false,
}
tag := mapper.BuildGormTag(col, nil)
if !strings.Contains(tag, "type:text[]") {
t.Fatalf("expected array type to be preserved, got %q", tag)
}
}
func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) {
mapper := NewTypeMapper("")
col := &models.Column{
Name: "embedding",
Type: "vector(1536)",
Length: 1536,
Precision: 0,
Scale: 0,
}
tag := mapper.BuildGormTag(col, nil)
if !strings.Contains(tag, "type:vector(1536)") {
t.Fatalf("expected explicit modifier to be preserved, got %q", tag)
}
if strings.Contains(tag, ")(") {
t.Fatalf("type modifier appears duplicated in %q", tag)
}
}

View File

@@ -4,6 +4,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
) )
func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table *models.Table, schema *models.Schema) string { func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table *models.Table, schema *models.Schema) string {
@@ -33,12 +34,11 @@ func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table *
} }
// Standard type mappings // Standard type mappings
baseType := strings.Split(sqlType, "(")[0] // Remove length/precision baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
baseType = strings.TrimSpace(baseType)
// Handle array types // Handle array types
if strings.HasSuffix(baseType, "[]") { if pgsql.IsArrayType(sqlType) {
elemType := strings.TrimSuffix(baseType, "[]") elemType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(pgsql.ElementType(sqlType)))
gqlType := w.mapBaseTypeToGraphQL(elemType) gqlType := w.mapBaseTypeToGraphQL(elemType)
return "[" + gqlType + "]" return "[" + gqlType + "]"
} }
@@ -108,8 +108,7 @@ func (w *Writer) sqlTypeToCustomScalar(sqlType string) string {
"date": "Date", "date": "Date",
} }
baseType := strings.Split(sqlType, "(")[0] baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
baseType = strings.TrimSpace(baseType)
if scalar, ok := scalarMap[baseType]; ok { if scalar, ok := scalarMap[baseType]; ok {
return scalar return scalar
@@ -132,8 +131,7 @@ func (w *Writer) isIntegerType(sqlType string) bool {
"smallserial": true, "smallserial": true,
} }
baseType := strings.Split(sqlType, "(")[0] baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
baseType = strings.TrimSpace(baseType)
return intTypes[baseType] return intTypes[baseType]
} }

View File

@@ -31,6 +31,10 @@ type MigrationWriter struct {
// NewMigrationWriter creates a new templated migration writer // NewMigrationWriter creates a new templated migration writer
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) { func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
if options == nil {
options = &writers.WriterOptions{}
}
executor, err := NewTemplateExecutor(options.FlattenSchema) executor, err := NewTemplateExecutor(options.FlattenSchema)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create template executor: %w", err) return nil, fmt.Errorf("failed to create template executor: %w", err)
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
// WriteMigration generates migration scripts using templates // WriteMigration generates migration scripts using templates
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error { func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
if model == nil {
return fmt.Errorf("model database is required")
}
if w.options == nil {
w.options = &writers.WriterOptions{}
}
if current == nil {
current = models.InitDatabase(model.Name)
}
var writer io.Writer var writer io.Writer
var file *os.File var file *os.File
var err error var err error
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
// Process each schema in the model // Process each schema in the model
for _, modelSchema := range model.Schemas { for _, modelSchema := range model.Schemas {
if modelSchema == nil {
continue
}
// Find corresponding schema in current database // Find corresponding schema in current database
var currentSchema *models.Schema var currentSchema *models.Schema
for _, cs := range current.Schemas { for _, cs := range current.Schemas {
if cs == nil {
continue
}
if strings.EqualFold(cs.Name, modelSchema.Name) { if strings.EqualFold(cs.Name, modelSchema.Name) {
currentSchema = cs currentSchema = cs
break break
@@ -329,8 +350,12 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
// Column doesn't exist, add it // Column doesn't exist, add it
defaultVal := "" defaultVal := ""
if modelCol.Default != nil { if modelCol.Default != nil {
if value, ok := modelCol.Default.(string); ok {
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
} else {
defaultVal = fmt.Sprintf("%v", modelCol.Default) defaultVal = fmt.Sprintf("%v", modelCol.Default)
} }
}
sql, err := w.executor.ExecuteAddColumn(AddColumnData{ sql, err := w.executor.ExecuteAddColumn(AddColumnData{
SchemaName: schema.Name, SchemaName: schema.Name,
@@ -382,8 +407,12 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
setDefault := modelCol.Default != nil setDefault := modelCol.Default != nil
defaultVal := "" defaultVal := ""
if setDefault { if setDefault {
if value, ok := modelCol.Default.(string); ok {
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
} else {
defaultVal = fmt.Sprintf("%v", modelCol.Default) defaultVal = fmt.Sprintf("%v", modelCol.Default)
} }
}
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{ sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
SchemaName: schema.Name, SchemaName: schema.Name,
@@ -537,12 +566,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
indexType = modelIndex.Type indexType = modelIndex.Type
} }
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
if len(columnExprs) == 0 {
continue
}
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{ sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
SchemaName: model.Name, SchemaName: model.Name,
TableName: modelTable.Name, TableName: modelTable.Name,
IndexName: indexName, IndexName: indexName,
IndexType: indexType, IndexType: indexType,
Columns: strings.Join(modelIndex.Columns, ", "), Columns: strings.Join(columnExprs, ", "),
Unique: modelIndex.Unique, Unique: modelIndex.Unique,
}) })
if err != nil { if err != nil {
@@ -565,6 +599,27 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
return scripts, nil return scripts, nil
} }
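// buildIndexColumnExpressions resolves index column names against the table and,
// for GIN indexes over plain text columns, appends an operator class (taken from the
// index comment, defaulting to gin_trgm_ops). Array columns keep the default GIN
// opclass, and names that cannot be resolved are passed through unchanged.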
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
if table != nil {
if col, ok := resolveIndexColumn(table, colName); ok && col != nil {
colExpr = col.SQLName()
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
}
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
}
}
columnExprs = append(columnExprs, colExpr)
}
return columnExprs
}
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates // generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) { func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0) scripts := make([]MigrationScript, 0)

View File

@@ -57,6 +57,169 @@ func TestWriteMigration_NewTable(t *testing.T) {
} }
} }
func TestWriteMigration_ArrayDefault(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
tagsCol.NotNull = true
tagsCol.Default = "''{}''"
table.Columns["tags"] = tagsCol
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
err = writer.WriteMigration(model, current)
if err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "tags text[] DEFAULT '{}' NOT NULL") {
t.Fatalf("expected normalized array default in migration, got:\n%s", output)
}
if strings.Contains(output, "'''{}'''") {
t.Fatalf("migration still contains triple-quoted array default:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("articles", "public")
titleCol := models.InitColumn("title", "articles", "public")
titleCol.Type = "text"
table.Columns["title"] = titleCol
index := &models.Index{
Name: "idx_articles_title_gin",
Type: "gin",
Columns: []string{"title"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("agent_personas", "public")
nameCol := models.InitColumn("name", "agent_personas", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
index := &models.Index{
Name: "idx_agent_personas_name_gin",
Type: "gin",
Columns: []string{`"name"`},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (name gin_trgm_ops)") {
t.Fatalf("expected quoted text column GIN index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
index := &models.Index{
Name: "idx_plans_tags",
Type: "gin",
Columns: []string{"tags"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (tags)") {
t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
}
}
func TestWriteMigration_WithAudit(t *testing.T) {
// Current database (empty)
current := models.InitDatabase("testdb")
@@ -282,3 +445,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
t.Error("Migration missing FOREIGN KEY") t.Error("Migration missing FOREIGN KEY")
} }
} }
func TestNewMigrationWriter_NilOptions(t *testing.T) {
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
}
if writer == nil {
t.Fatal("expected writer instance")
}
if writer.options == nil {
t.Fatal("expected default writer options to be initialized")
}
}
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
table.Columns["id"] = idCol
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, nil); err != nil {
t.Fatalf("WriteMigration with nil current failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "CREATE TABLE") {
t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
}
}

View File

@@ -8,6 +8,7 @@ import (
"text/template" "text/template"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
//go:embed templates/*.tmpl //go:embed templates/*.tmpl
@@ -266,6 +267,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
}
// Execute methods for each template
@@ -495,8 +497,12 @@ func BuildCreateTableData(schemaName string, table *models.Table) CreateTableDat
NotNull: col.NotNull,
}
if col.Default != nil {
if value, ok := col.Default.(string); ok {
colData.Default = writers.QuoteDefaultValue(value, col.Type)
} else {
colData.Default = fmt.Sprintf("%v", col.Default)
}
}
columns = append(columns, colData)
}

View File

@@ -1,26 +1,42 @@
DO $$
DECLARE
- auto_pk_name text;
+ current_pk_name text;
+ current_pk_matches boolean := false;
BEGIN
- -- Drop auto-generated primary key if it exists
- SELECT constraint_name INTO auto_pk_name
- FROM information_schema.table_constraints
- WHERE table_schema = '{{.SchemaName}}'
- AND table_name = '{{.TableName}}'
- AND constraint_type = 'PRIMARY KEY'
- AND constraint_name IN ({{.AutoGenNames}});
+ SELECT tc.constraint_name,
+ COALESCE(
+ ARRAY(
+ SELECT a.attname::text
+ FROM pg_constraint c
+ JOIN pg_class t ON t.oid = c.conrelid
+ JOIN pg_namespace n ON n.oid = t.relnamespace
+ JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
+ ON TRUE
+ JOIN pg_attribute a
+ ON a.attrelid = t.oid
+ AND a.attnum = cols.attnum
+ WHERE c.contype = 'p'
+ AND n.nspname = '{{.SchemaName}}'
+ AND t.relname = '{{.TableName}}'
+ ORDER BY cols.ord
+ ),
+ ARRAY[]::text[]
+ ) = ARRAY[{{.ColumnNames}}]
+ INTO current_pk_name, current_pk_matches
+ FROM information_schema.table_constraints tc
+ WHERE tc.table_schema = '{{.SchemaName}}'
+ AND tc.table_name = '{{.TableName}}'
+ AND tc.constraint_type = 'PRIMARY KEY';
- IF auto_pk_name IS NOT NULL THEN
- EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
+ IF current_pk_name IS NOT NULL
+ AND NOT current_pk_matches
+ AND current_pk_name IN ({{.AutoGenNames}}) THEN
+ EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
END IF;
- -- Add named primary key if it doesn't exist
- IF NOT EXISTS (
- SELECT 1 FROM information_schema.table_constraints
- WHERE table_schema = '{{.SchemaName}}'
- AND table_name = '{{.TableName}}'
- AND constraint_name = '{{.ConstraintName}}'
- ) THEN
+ -- Add the desired primary key only when no matching primary key already exists.
+ IF current_pk_name IS NULL
+ OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
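For orientation, a sketch of the Go data this template is rendered with (values are illustrative and mirror what the updated writer tests assert; it assumes SchemaName and TableName fields exist on CreatePrimaryKeyWithAutoGenCheckData alongside the fields shown in the executor diff):
data := CreatePrimaryKeyWithAutoGenCheckData{
SchemaName: "public",
TableName: "products",
ConstraintName: "pk_public_products",
AutoGenNames: "'products_pkey', 'public_products_pkey'",
Columns: "id",
ColumnNames: "'id'",
}
// {{.ColumnNames}} renders as the ARRAY['id'] comparison target, while
// {{.AutoGenNames}} feeds the IN (...) guard that limits which primary keys
// may be dropped; the statement is produced via
// executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data).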

View File

@@ -10,8 +10,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
@@ -230,6 +228,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "),
ColumnNames: formatStringList(pkColumns),
}
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -262,7 +261,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
- if col, ok := table.Columns[colName]; ok {
+ if col, ok := resolveIndexColumn(table, colName); ok {
// For GIN indexes on text columns, add operator class
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
@@ -493,18 +492,19 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
// Type with length/precision - convert to valid PostgreSQL type
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
// Only add size specifiers for types that support them
- if col.Length > 0 && col.Precision == 0 {
+ if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
- if supportsLength(baseType) {
+ if pgsql.SupportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
// Convert text with length to varchar
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
// For types that don't support length (integer, bigint, etc.), ignore the length
- } else if col.Precision > 0 {
+ } else if !hasExplicitTypeModifier && col.Precision > 0 {
- if supportsPrecision(baseType) {
+ if pgsql.SupportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
@@ -524,15 +524,7 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
if col.Default != nil {
switch v := col.Default.(type) {
case string:
- // Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
- cleanDefault := stripBackticks(v)
- if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
- parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
- } else if cleanDefault == "true" || cleanDefault == "false" {
- parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
- } else {
- parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
- }
+ parts = append(parts, fmt.Sprintf("DEFAULT %s", writers.QuoteDefaultValue(stripBackticks(v), col.Type)))
case bool:
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
default:
@@ -815,6 +807,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "),
ColumnNames: formatStringList(columnNames),
}
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -862,7 +855,7 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
// Build column list with operator class support for GIN indexes
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
- if col, ok := table.Columns[colName]; ok {
+ if col, ok := resolveIndexColumn(table, colName); ok {
colExpr := col.SQLName()
// For GIN indexes on text columns, add operator class
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
@@ -1260,6 +1253,9 @@ func isIntegerType(colType string) bool {
func isTextType(colType string) bool {
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
lowerType := strings.ToLower(colType)
if strings.HasSuffix(lowerType, "[]") {
return false
}
for _, t := range textTypes {
if strings.HasPrefix(lowerType, t) {
return true
@@ -1268,35 +1264,38 @@ func isTextType(colType string) bool {
return false
}
// supportsLength checks if a PostgreSQL type supports length specification
func supportsLength(colType string) bool {
lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"}
lowerType := strings.ToLower(colType)
for _, t := range lengthTypes {
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
return true
}
}
return false
}
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification
func supportsPrecision(colType string) bool {
precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"}
lowerType := strings.ToLower(colType)
for _, t := range precisionTypes {
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
return true
}
}
return false
}
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
func isTextTypeWithoutLength(colType string) bool {
return strings.EqualFold(colType, "text")
}
func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
if table == nil {
return nil, false
}
if col, ok := table.Columns[colName]; ok && col != nil {
return col, true
}
normalized := strings.ToLower(strings.Trim(colName, `"`))
for key, col := range table.Columns {
if col == nil {
continue
}
if strings.ToLower(strings.Trim(key, `"`)) == normalized {
return col, true
}
if strings.ToLower(strings.Trim(col.Name, `"`)) == normalized {
return col, true
}
if strings.ToLower(strings.Trim(col.SQLName(), `"`)) == normalized {
return col, true
}
}
return nil, false
}
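A short in-package sketch of the lookup behavior (table and column values are hypothetical; assumes fmt is imported):
table := models.InitTable("agent_personas", "public")
nameCol := models.InitColumn("name", "agent_personas", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
col, ok := resolveIndexColumn(table, `"name"`) // quoted reference still resolves
upper, okUpper := resolveIndexColumn(table, "NAME") // case-insensitive fallback
_, okMissing := resolveIndexColumn(table, "missing") // unknown names report false
fmt.Println(ok, okUpper, okMissing, col == upper) // true true false true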
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
func formatStringList(items []string) string {
quoted := make([]string, len(items))
@@ -1376,7 +1375,7 @@ func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) erro
// Connect to database
ctx := context.Background()
- conn, err := pgx.Connect(ctx, connString)
+ conn, err := pgsql.Connect(ctx, connString, "writer-pgsql")
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}

View File

@@ -87,6 +87,77 @@ func TestWriteDatabase(t *testing.T) {
}
}
func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
index := &models.Index{
Name: "idx_plans_tags",
Type: "gin",
Columns: []string{"tags"},
}
table.Indexes[index.Name] = index
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `USING gin (tags)`) {
t.Fatalf("expected GIN index on array column without explicit trigram opclass, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
}
}
func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("agent_personas", "public")
nameCol := models.InitColumn("name", "agent_personas", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
index := &models.Index{
Name: "idx_agent_personas_name_gin",
Type: "gin",
Columns: []string{`"name"`},
}
table.Indexes[index.Name] = index
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) {
t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteForeignKeys(t *testing.T) {
// Create a test database with two related tables
db := models.InitDatabase("testdb")
@@ -636,9 +707,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output) t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
} }
// Verify it checks for our specific named constraint before adding it // Verify it compares the current primary key columns before dropping/recreating
if !strings.Contains(output, "constraint_name = 'pk_public_products'") { if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output) t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
}
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
} }
} }
@@ -729,6 +805,93 @@ func TestColumnSizeSpecifiers(t *testing.T) {
}
}
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
idCol := models.InitColumn("id", "learnings", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
parentCol.Type = "bigint"
table.Columns["duplicate_of_learning_id"] = parentCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "current_pk_matches") {
t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
}
if !strings.Contains(output, "ARRAY['id']") {
t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
}
if !strings.Contains(output, "NOT current_pk_matches") {
t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
}
}
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
writer := NewWriter(&writers.WriterOptions{})
cases := []struct {
name string
colType string
length int
precision int
scale int
wantType string
}{
{
name: "character varying already includes length",
colType: "character varying(50)",
length: 50,
wantType: "character varying(50)",
},
{
name: "numeric already includes precision",
colType: "numeric(10,2)",
precision: 10,
scale: 2,
wantType: "numeric(10,2)",
},
{
name: "custom vector modifier preserved",
colType: "vector(1536)",
wantType: "vector(1536)",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
col := models.InitColumn("sample", "events", "public")
col.Type = tc.colType
col.Length = tc.length
col.Precision = tc.precision
col.Scale = tc.scale
def := writer.generateColumnDefinition(col)
if !strings.Contains(def, " "+tc.wantType+" ") && !strings.HasSuffix(def, " "+tc.wantType) {
t.Fatalf("generated definition %q does not contain expected type %q", def, tc.wantType)
}
if strings.Contains(def, ")(") {
t.Fatalf("generated definition %q appears to duplicate modifiers", def)
}
})
}
}
func TestGenerateAddColumnStatements(t *testing.T) {
// Create a test database with tables that have new columns
db := models.InitDatabase("testdb")

View File

@@ -61,7 +61,7 @@ func (w *Writer) databaseToPrisma(db *models.Database) string {
sb.WriteString("\n") sb.WriteString("\n")
// Write generator block // Write generator block
sb.WriteString(w.generateGenerator()) sb.WriteString(w.generateGenerator(db))
sb.WriteString("\n") sb.WriteString("\n")
// Process all schemas (typically just one in Prisma) // Process all schemas (typically just one in Prisma)
@@ -114,13 +114,28 @@ func (w *Writer) generateDatasource(db *models.Database) string {
}
// generateGenerator generates the generator block
- func (w *Writer) generateGenerator() string {
+ func (w *Writer) generateGenerator(db *models.Database) string {
if w.usePrisma7Generator(db) {
return `generator client {
provider = "prisma-client"
output = "./generated"
}
`
}
return `generator client {
provider = "prisma-client-js"
}
`
}
func (w *Writer) usePrisma7Generator(db *models.Database) bool {
if w.options != nil && w.options.Prisma7 {
return true
}
return db != nil && db.SourceFormat == "prisma7"
}
// enumToPrisma converts an Enum to Prisma enum block
func (w *Writer) enumToPrisma(enum *models.Enum) string {
var sb strings.Builder

View File

@@ -0,0 +1,52 @@
package prisma
import (
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
func TestGenerateGenerator_DefaultsToPrismaClientJS(t *testing.T) {
t.Parallel()
writer := NewWriter(&writers.WriterOptions{})
db := models.InitDatabase("testdb")
got := writer.generateGenerator(db)
if !strings.Contains(got, `provider = "prisma-client-js"`) {
t.Fatalf("expected prisma-client-js generator, got:\n%s", got)
}
if strings.Contains(got, `output = "./generated"`) {
t.Fatalf("did not expect prisma7 output path in default generator:\n%s", got)
}
}
func TestGenerateGenerator_Prisma7FlagUsesPrismaClient(t *testing.T) {
t.Parallel()
writer := NewWriter(&writers.WriterOptions{Prisma7: true})
db := models.InitDatabase("testdb")
got := writer.generateGenerator(db)
if !strings.Contains(got, `provider = "prisma-client"`) {
t.Fatalf("expected prisma-client generator, got:\n%s", got)
}
if !strings.Contains(got, `output = "./generated"`) {
t.Fatalf("expected prisma7 output path, got:\n%s", got)
}
}
func TestGenerateGenerator_Prisma7SourceFormatUsesPrismaClient(t *testing.T) {
t.Parallel()
writer := NewWriter(&writers.WriterOptions{})
db := models.InitDatabase("testdb")
db.SourceFormat = "prisma7"
got := writer.generateGenerator(db)
if !strings.Contains(got, `provider = "prisma-client"`) {
t.Fatalf("expected prisma-client generator from source format, got:\n%s", got)
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/writers" "git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
@@ -42,7 +43,7 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
// Connect to database
ctx := context.Background()
- conn, err := pgx.Connect(ctx, connString)
+ conn, err := pgsql.Connect(ctx, connString, "writer-sqlexec")
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
@@ -72,7 +73,7 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
// Connect to database
ctx := context.Background()
- conn, err := pgx.Connect(ctx, connString)
+ conn, err := pgsql.Connect(ctx, connString, "writer-sqlexec")
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}

View File

@@ -20,6 +20,18 @@ type Writer interface {
WriteTable(table *models.Table) error
}
// NullableType constants control which Go package is used for nullable column types
// in code-generation writers (Bun, GORM).
const (
// NullableTypeResolveSpec uses github.com/bitechdev/ResolveSpec/pkg/spectypes
// (SqlString, SqlInt32, SqlVector, SqlStringArray, …). This is the default.
NullableTypeResolveSpec = "resolvespec"
// NullableTypeStdlib uses the standard library database/sql nullable types
// (sql.NullString, sql.NullInt32, …) and plain Go slices for arrays.
NullableTypeStdlib = "stdlib"
)
// WriterOptions contains common options for writers
type WriterOptions struct {
// OutputPath is the path where the output should be written
@@ -33,6 +45,15 @@ type WriterOptions struct {
// Useful for databases like SQLite that do not support schemas.
FlattenSchema bool
// NullableTypes selects the Go type package used for nullable columns in
// code-generation writers (bun, gorm). Accepted values:
// "resolvespec" (default) — github.com/bitechdev/ResolveSpec/pkg/spectypes
// "stdlib" — database/sql (sql.NullString, sql.NullInt32, …)
NullableTypes string
// Prisma7 enables Prisma 7-specific output for Prisma writers.
Prisma7 bool
// Additional options can be added here as needed
Metadata map[string]interface{}
}
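Taken together, a caller would opt into the new options roughly like this (a sketch; the field values are illustrative):
opts := &writers.WriterOptions{
OutputPath: "./out",
NullableTypes: writers.NullableTypeStdlib, // or writers.NullableTypeResolveSpec (the default)
Prisma7: true, // emit the Prisma 7 generator block
}
_ = opts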
@@ -92,8 +113,12 @@ func SanitizeFilename(name string) string {
// Examples (bigint): "0" → "0"
// Examples (timestamp): "now()" → "now()" (function call never quoted)
func QuoteDefaultValue(value, sqlType string) string {
value = strings.TrimSpace(value)
// Function calls are never quoted regardless of column type.
- if strings.Contains(value, "(") || strings.Contains(value, ")") {
+ if strings.Contains(value, "(") || strings.Contains(value, ")") ||
+ strings.Contains(value, "::") ||
+ strings.HasPrefix(strings.ToUpper(value), "ARRAY[") {
return value
}
@@ -103,6 +128,16 @@ func QuoteDefaultValue(value, sqlType string) string {
baseType = baseType[:idx]
}
if isArraySQLType(baseType) {
if arrayLiteral, ok := normalizeArrayDefaultLiteral(value); ok {
return quoteSQLLiteral(arrayLiteral)
}
}
if isQuotedSQLLiteral(value) {
return value
}
// Types whose default values must NOT be quoted.
unquotedTypes := map[string]bool{
// Integer types
@@ -136,7 +171,32 @@ func QuoteDefaultValue(value, sqlType string) string {
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
// is treated as a quoted literal.
- return "'" + value + "'"
+ return quoteSQLLiteral(value)
}
func isArraySQLType(sqlType string) bool {
return strings.HasSuffix(sqlType, "[]")
}
func normalizeArrayDefaultLiteral(value string) (string, bool) {
switch {
case strings.HasPrefix(value, "''{") && strings.HasSuffix(value, "}''"):
return value[2 : len(value)-2], true
case strings.HasPrefix(value, "'{") && strings.HasSuffix(value, "}'"):
return value[1 : len(value)-1], true
case strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}"):
return value, true
default:
return "", false
}
}
func isQuotedSQLLiteral(value string) bool {
return len(value) >= 2 && strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")
}
func quoteSQLLiteral(value string) string {
return "'" + strings.ReplaceAll(value, "'", "''") + "'"
}
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
@@ -147,7 +207,8 @@ func QuoteDefaultValue(value, sqlType string) string {
// - Returns a clean identifier safe for use in struct tags and field names
func SanitizeStructTagValue(value string) string {
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
- commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
+ // Require at least one character inside brackets to avoid stripping PostgreSQL array suffix "[]"
+ commentRegex := regexp.MustCompile(`\s*\[[^\]]+\]\s*`)
value = commentRegex.ReplaceAllString(value, "")
// Trim whitespace
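The tightened pattern only strips bracketed annotations that contain at least one character, so a bare array suffix survives. A quick check of just the regex, with a hypothetical input string:
re := regexp.MustCompile(`\s*\[[^\]]+\]\s*`)
fmt.Println(re.ReplaceAllString("text[] [note: 'labels']", "")) // prints "text[]"
// The previous pattern `\s*\[.*?\]\s*` would also have removed the empty "[]".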

View File

@@ -0,0 +1,54 @@
package writers
import "testing"
func TestQuoteDefaultValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value string
sqlType string
want string
}{
{
name: "text default is quoted",
value: "active",
sqlType: "text",
want: "'active'",
},
{
name: "array default from bare literal is quoted once",
value: "{}",
sqlType: "text[]",
want: "'{}'",
},
{
name: "array default from quoted literal is preserved",
value: "'{}'",
sqlType: "text[]",
want: "'{}'",
},
{
name: "array default from double quoted literal is normalized",
value: "''{}''",
sqlType: "text[]",
want: "'{}'",
},
{
name: "function default is left alone",
value: "now()",
sqlType: "timestamptz",
want: "now()",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := QuoteDefaultValue(tt.value, tt.sqlType)
if got != tt.want {
t.Fatalf("QuoteDefaultValue(%q, %q) = %q, want %q", tt.value, tt.sqlType, got, tt.want)
}
})
}
}