Compare commits
20 Commits
1e54fdcd7f
...
v1.0.56
| Author | SHA1 | Date | |
|---|---|---|---|
| 30ef1db010 | |||
| 2d97a47ee1 | |||
| 72200ea72e | |||
| 608893a3d6 | |||
| 53ff745d5d | |||
| 17bc8ed395 | |||
| a447b68b22 | |||
| 4303dcf59b | |||
| e828d48798 | |||
| 6e470a9239 | |||
| 096815fe49 | |||
| b8f60203cb | |||
|
|
15763f60cc | ||
|
|
6d2884f5cf | ||
|
|
f192decff8 | ||
|
|
8b906cf4a3 | ||
|
|
0a3966e6fc | ||
|
|
d30fc24f55 | ||
|
|
16a489d0b8 | ||
|
|
3524e86282 |
@@ -52,6 +52,7 @@ var (
|
|||||||
convertPackageName string
|
convertPackageName string
|
||||||
convertSchemaFilter string
|
convertSchemaFilter string
|
||||||
convertFlattenSchema bool
|
convertFlattenSchema bool
|
||||||
|
convertNullableTypes string
|
||||||
)
|
)
|
||||||
|
|
||||||
var convertCmd = &cobra.Command{
|
var convertCmd = &cobra.Command{
|
||||||
@@ -175,6 +176,7 @@ func init() {
|
|||||||
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
||||||
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
||||||
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||||
|
convertCmd.Flags().StringVar(&convertNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
|
||||||
|
|
||||||
err := convertCmd.MarkFlagRequired("from")
|
err := convertCmd.MarkFlagRequired("from")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -241,7 +243,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema); err != nil {
|
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes); err != nil {
|
||||||
return fmt.Errorf("failed to write target: %w", err)
|
return fmt.Errorf("failed to write target: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,79 +286,79 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DBML format")
|
return nil, fmt.Errorf("file path is required for DBML format")
|
||||||
}
|
}
|
||||||
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dbml.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DCTX format")
|
return nil, fmt.Errorf("file path is required for DCTX format")
|
||||||
}
|
}
|
||||||
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dctx.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DrawDB format")
|
return nil, fmt.Errorf("file path is required for DrawDB format")
|
||||||
}
|
}
|
||||||
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drawdb.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for JSON format")
|
return nil, fmt.Errorf("file path is required for JSON format")
|
||||||
}
|
}
|
||||||
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = json.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "yaml", "yml":
|
case "yaml", "yml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for YAML format")
|
return nil, fmt.Errorf("file path is required for YAML format")
|
||||||
}
|
}
|
||||||
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = yaml.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "pgsql", "postgres", "postgresql":
|
case "pgsql", "postgres", "postgresql":
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
return nil, fmt.Errorf("connection string is required for PostgreSQL format")
|
return nil, fmt.Errorf("connection string is required for PostgreSQL format")
|
||||||
}
|
}
|
||||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
reader = pgsql.NewReader(newReaderOptions("", connString))
|
||||||
|
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for GORM format")
|
return nil, fmt.Errorf("file path is required for GORM format")
|
||||||
}
|
}
|
||||||
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = gorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for Bun format")
|
return nil, fmt.Errorf("file path is required for Bun format")
|
||||||
}
|
}
|
||||||
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = bun.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for Drizzle format")
|
return nil, fmt.Errorf("file path is required for Drizzle format")
|
||||||
}
|
}
|
||||||
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drizzle.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for Prisma format")
|
return nil, fmt.Errorf("file path is required for Prisma format")
|
||||||
}
|
}
|
||||||
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = prisma.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for TypeORM format")
|
return nil, fmt.Errorf("file path is required for TypeORM format")
|
||||||
}
|
}
|
||||||
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = typeorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "graphql", "gql":
|
case "graphql", "gql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for GraphQL format")
|
return nil, fmt.Errorf("file path is required for GraphQL format")
|
||||||
}
|
}
|
||||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = graphql.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022":
|
case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022":
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
return nil, fmt.Errorf("connection string is required for MSSQL format")
|
return nil, fmt.Errorf("connection string is required for MSSQL format")
|
||||||
}
|
}
|
||||||
reader = mssql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
reader = mssql.NewReader(newReaderOptions("", connString))
|
||||||
|
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
// SQLite can use either file path or connection string
|
// SQLite can use either file path or connection string
|
||||||
@@ -367,7 +369,7 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
|||||||
if dbPath == "" {
|
if dbPath == "" {
|
||||||
return nil, fmt.Errorf("file path or connection string is required for SQLite format")
|
return nil, fmt.Errorf("file path or connection string is required for SQLite format")
|
||||||
}
|
}
|
||||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported source format: %s", dbType)
|
return nil, fmt.Errorf("unsupported source format: %s", dbType)
|
||||||
@@ -381,14 +383,10 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool) error {
|
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error {
|
||||||
var writer writers.Writer
|
var writer writers.Writer
|
||||||
|
|
||||||
writerOpts := &writers.WriterOptions{
|
writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes)
|
||||||
OutputPath: outputPath,
|
|
||||||
PackageName: packageName,
|
|
||||||
FlattenSchema: flattenSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
switch strings.ToLower(dbType) {
|
switch strings.ToLower(dbType) {
|
||||||
case "dbml":
|
case "dbml":
|
||||||
|
|||||||
@@ -240,62 +240,62 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
|
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
|
||||||
}
|
}
|
||||||
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dbml.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
|
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||||
}
|
}
|
||||||
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dctx.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
|
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||||
}
|
}
|
||||||
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drawdb.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "graphql":
|
case "graphql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
|
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||||
}
|
}
|
||||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = graphql.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
|
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
|
||||||
}
|
}
|
||||||
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = json.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "yaml":
|
case "yaml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
|
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
|
||||||
}
|
}
|
||||||
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = yaml.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
|
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
|
||||||
}
|
}
|
||||||
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = gorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
|
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
|
||||||
}
|
}
|
||||||
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = bun.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
|
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||||
}
|
}
|
||||||
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drizzle.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
|
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||||
}
|
}
|
||||||
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = prisma.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
|
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||||
}
|
}
|
||||||
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = typeorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
||||||
}
|
}
|
||||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
reader = pgsql.NewReader(newReaderOptions("", connString))
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
// SQLite can use either file path or connection string
|
// SQLite can use either file path or connection string
|
||||||
dbPath := filePath
|
dbPath := filePath
|
||||||
@@ -305,7 +305,7 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
|
|||||||
if dbPath == "" {
|
if dbPath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
||||||
}
|
}
|
||||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
||||||
}
|
}
|
||||||
@@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa
|
|||||||
|
|
||||||
switch strings.ToLower(dbType) {
|
switch strings.ToLower(dbType) {
|
||||||
case "dbml":
|
case "dbml":
|
||||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "dctx":
|
case "dctx":
|
||||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "graphql":
|
case "graphql":
|
||||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "json":
|
case "json":
|
||||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wjson.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "yaml":
|
case "yaml":
|
||||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "gorm":
|
case "gorm":
|
||||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "bun":
|
case "bun":
|
||||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wbun.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "prisma":
|
case "prisma":
|
||||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, ""))
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
return fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -221,73 +221,73 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DBML format")
|
return nil, fmt.Errorf("file path is required for DBML format")
|
||||||
}
|
}
|
||||||
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dbml.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DCTX format")
|
return nil, fmt.Errorf("file path is required for DCTX format")
|
||||||
}
|
}
|
||||||
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dctx.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DrawDB format")
|
return nil, fmt.Errorf("file path is required for DrawDB format")
|
||||||
}
|
}
|
||||||
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drawdb.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "graphql":
|
case "graphql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for GraphQL format")
|
return nil, fmt.Errorf("file path is required for GraphQL format")
|
||||||
}
|
}
|
||||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = graphql.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for JSON format")
|
return nil, fmt.Errorf("file path is required for JSON format")
|
||||||
}
|
}
|
||||||
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = json.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "yaml", "yml":
|
case "yaml", "yml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for YAML format")
|
return nil, fmt.Errorf("file path is required for YAML format")
|
||||||
}
|
}
|
||||||
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = yaml.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for GORM format")
|
return nil, fmt.Errorf("file path is required for GORM format")
|
||||||
}
|
}
|
||||||
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = gorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for Bun format")
|
return nil, fmt.Errorf("file path is required for Bun format")
|
||||||
}
|
}
|
||||||
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = bun.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for Drizzle format")
|
return nil, fmt.Errorf("file path is required for Drizzle format")
|
||||||
}
|
}
|
||||||
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drizzle.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for Prisma format")
|
return nil, fmt.Errorf("file path is required for Prisma format")
|
||||||
}
|
}
|
||||||
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = prisma.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for TypeORM format")
|
return nil, fmt.Errorf("file path is required for TypeORM format")
|
||||||
}
|
}
|
||||||
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = typeorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
|
|
||||||
case "pgsql", "postgres", "postgresql":
|
case "pgsql", "postgres", "postgresql":
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
return nil, fmt.Errorf("connection string is required for PostgreSQL format")
|
return nil, fmt.Errorf("connection string is required for PostgreSQL format")
|
||||||
}
|
}
|
||||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
reader = pgsql.NewReader(newReaderOptions("", connString))
|
||||||
|
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
// SQLite can use either file path or connection string
|
// SQLite can use either file path or connection string
|
||||||
@@ -298,7 +298,7 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
|
|||||||
if dbPath == "" {
|
if dbPath == "" {
|
||||||
return nil, fmt.Errorf("file path or connection string is required for SQLite format")
|
return nil, fmt.Errorf("file path or connection string is required for SQLite format")
|
||||||
}
|
}
|
||||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||||
|
|||||||
@@ -258,6 +258,11 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
|
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
|
||||||
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
|
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
|
||||||
|
|
||||||
|
if strings.EqualFold(mergeOutputType, "pgsql") && len(result.TypeConflicts) > 0 {
|
||||||
|
return fmt.Errorf("merge detected conflicting existing column types and cannot safely continue with pgsql output\n%s",
|
||||||
|
merge.GetColumnTypeConflictSummary(result, 10))
|
||||||
|
}
|
||||||
|
|
||||||
// Step 4: Write output
|
// Step 4: Write output
|
||||||
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
|
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
|
||||||
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
|
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
|
||||||
@@ -284,62 +289,62 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
|
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
|
||||||
}
|
}
|
||||||
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dbml.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
|
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||||
}
|
}
|
||||||
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = dctx.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
|
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||||
}
|
}
|
||||||
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drawdb.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "graphql":
|
case "graphql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
|
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||||
}
|
}
|
||||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = graphql.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
|
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
|
||||||
}
|
}
|
||||||
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = json.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "yaml":
|
case "yaml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
|
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
|
||||||
}
|
}
|
||||||
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = yaml.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
|
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
|
||||||
}
|
}
|
||||||
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = gorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
|
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
|
||||||
}
|
}
|
||||||
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = bun.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
|
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||||
}
|
}
|
||||||
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = drizzle.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
|
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||||
}
|
}
|
||||||
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = prisma.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
|
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||||
}
|
}
|
||||||
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
reader = typeorm.NewReader(newReaderOptions(filePath, ""))
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
||||||
}
|
}
|
||||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
reader = pgsql.NewReader(newReaderOptions("", connString))
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
// SQLite can use either file path or connection string
|
// SQLite can use either file path or connection string
|
||||||
dbPath := filePath
|
dbPath := filePath
|
||||||
@@ -349,7 +354,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
|
|||||||
if dbPath == "" {
|
if dbPath == "" {
|
||||||
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
||||||
}
|
}
|
||||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
reader = sqlite.NewReader(newReaderOptions(dbPath, ""))
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
||||||
}
|
}
|
||||||
@@ -370,61 +375,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DBML format", label)
|
return fmt.Errorf("%s: file path is required for DBML format", label)
|
||||||
}
|
}
|
||||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||||
}
|
}
|
||||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||||
}
|
}
|
||||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "graphql":
|
case "graphql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||||
}
|
}
|
||||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for JSON format", label)
|
return fmt.Errorf("%s: file path is required for JSON format", label)
|
||||||
}
|
}
|
||||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "yaml":
|
case "yaml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for YAML format", label)
|
return fmt.Errorf("%s: file path is required for YAML format", label)
|
||||||
}
|
}
|
||||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for GORM format", label)
|
return fmt.Errorf("%s: file path is required for GORM format", label)
|
||||||
}
|
}
|
||||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Bun format", label)
|
return fmt.Errorf("%s: file path is required for Bun format", label)
|
||||||
}
|
}
|
||||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||||
}
|
}
|
||||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||||
}
|
}
|
||||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||||
}
|
}
|
||||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}
|
writerOpts := newWriterOptions(filePath, "", flattenSchema, "")
|
||||||
if connString != "" {
|
if connString != "" {
|
||||||
writerOpts.Metadata = map[string]interface{}{
|
writerOpts.Metadata = map[string]interface{}{
|
||||||
"connection_string": connString,
|
"connection_string": connString,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,3 +161,38 @@ func TestRunMerge_FromListMissingSourceType(t *testing.T) {
|
|||||||
t.Error("expected error when neither --source-path nor --from-list is provided")
|
t.Error("expected error when neither --source-path nor --from-list is provided")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunMerge_PgsqlOutputRejectsColumnTypeConflict(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
targetFile := filepath.Join(dir, "target.json")
|
||||||
|
sourceFile := filepath.Join(dir, "source.json")
|
||||||
|
writeTestJSONWithSingleColumnType(t, targetFile, "users", "integer")
|
||||||
|
writeTestJSONWithSingleColumnType(t, sourceFile, "users", "uuid")
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = targetFile
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = sourceFile
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = nil
|
||||||
|
mergeOutputType = "pgsql"
|
||||||
|
mergeOutputPath = ""
|
||||||
|
mergeOutputConn = "postgres://relspec:secret@localhost/testdb"
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
err := runMerge(nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected pgsql output merge to fail on column type conflict")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "column type conflicts detected") {
|
||||||
|
t.Fatalf("expected conflict summary in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "public.users.id") {
|
||||||
|
t.Fatalf("expected conflicting column path in error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
24
cmd/relspec/prisma_options.go
Normal file
24
cmd/relspec/prisma_options.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newReaderOptions(filePath, connString string) *readers.ReaderOptions {
|
||||||
|
return &readers.ReaderOptions{
|
||||||
|
FilePath: filePath,
|
||||||
|
ConnectionString: connString,
|
||||||
|
Prisma7: prisma7,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions {
|
||||||
|
return &writers.WriterOptions{
|
||||||
|
OutputPath: outputPath,
|
||||||
|
PackageName: packageName,
|
||||||
|
FlattenSchema: flattenSchema,
|
||||||
|
NullableTypes: nullableTypes,
|
||||||
|
Prisma7: prisma7,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ var (
|
|||||||
// Version information, set via ldflags during build
|
// Version information, set via ldflags during build
|
||||||
version = "dev"
|
version = "dev"
|
||||||
buildDate = "unknown"
|
buildDate = "unknown"
|
||||||
|
prisma7 bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -68,4 +69,5 @@ func init() {
|
|||||||
rootCmd.AddCommand(mergeCmd)
|
rootCmd.AddCommand(mergeCmd)
|
||||||
rootCmd.AddCommand(splitCmd)
|
rootCmd.AddCommand(splitCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
|
rootCmd.PersistentFlags().BoolVar(&prisma7, "prisma7", false, "Use Prisma 7 generator conventions when reading/writing Prisma schemas")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ var (
|
|||||||
splitDatabaseName string
|
splitDatabaseName string
|
||||||
splitExcludeSchema string
|
splitExcludeSchema string
|
||||||
splitExcludeTables string
|
splitExcludeTables string
|
||||||
|
splitNullableTypes string
|
||||||
)
|
)
|
||||||
|
|
||||||
var splitCmd = &cobra.Command{
|
var splitCmd = &cobra.Command{
|
||||||
@@ -110,6 +111,7 @@ func init() {
|
|||||||
splitCmd.Flags().StringVar(&splitTables, "tables", "", "Comma-separated list of table names to include (case-insensitive)")
|
splitCmd.Flags().StringVar(&splitTables, "tables", "", "Comma-separated list of table names to include (case-insensitive)")
|
||||||
splitCmd.Flags().StringVar(&splitExcludeSchema, "exclude-schema", "", "Comma-separated list of schema names to exclude")
|
splitCmd.Flags().StringVar(&splitExcludeSchema, "exclude-schema", "", "Comma-separated list of schema names to exclude")
|
||||||
splitCmd.Flags().StringVar(&splitExcludeTables, "exclude-tables", "", "Comma-separated list of table names to exclude (case-insensitive)")
|
splitCmd.Flags().StringVar(&splitExcludeTables, "exclude-tables", "", "Comma-separated list of table names to exclude (case-insensitive)")
|
||||||
|
splitCmd.Flags().StringVar(&splitNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
|
||||||
|
|
||||||
err := splitCmd.MarkFlagRequired("from")
|
err := splitCmd.MarkFlagRequired("from")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -185,6 +187,7 @@ func runSplit(cmd *cobra.Command, args []string) error {
|
|||||||
splitPackageName,
|
splitPackageName,
|
||||||
"", // no schema filter for split
|
"", // no schema filter for split
|
||||||
false, // no flatten-schema for split
|
false, // no flatten-schema for split
|
||||||
|
splitNullableTypes,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write output: %w", err)
|
return fmt.Errorf("failed to write output: %w", err)
|
||||||
|
|||||||
@@ -71,6 +71,40 @@ func writeTestJSON(t *testing.T, path string, tableNames []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeTestJSONWithSingleColumnType(t *testing.T, path, tableName, columnType string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := minimalDatabase{
|
||||||
|
Name: "test_db",
|
||||||
|
Schemas: []minimalSchema{{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []minimalTable{{
|
||||||
|
Name: tableName,
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]minimalColumn{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Table: tableName,
|
||||||
|
Schema: "public",
|
||||||
|
Type: columnType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal test JSON: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convertState captures and restores all convert global vars.
|
// convertState captures and restores all convert global vars.
|
||||||
type convertState struct {
|
type convertState struct {
|
||||||
sourceType string
|
sourceType string
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
||||||
pkgname=relspec
|
pkgname=relspec
|
||||||
pkgver=1.0.44
|
pkgver=1.0.56
|
||||||
pkgrel=1
|
pkgrel=1
|
||||||
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
||||||
arch=('x86_64' 'aarch64')
|
arch=('x86_64' 'aarch64')
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
Name: relspec
|
Name: relspec
|
||||||
Version: 1.0.44
|
Version: 1.0.56
|
||||||
Release: 1%{?dist}
|
Release: 1%{?dist}
|
||||||
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,16 @@ type MergeResult struct {
|
|||||||
EnumsAdded int
|
EnumsAdded int
|
||||||
ViewsAdded int
|
ViewsAdded int
|
||||||
SequencesAdded int
|
SequencesAdded int
|
||||||
|
TypeConflicts []ColumnTypeConflict
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnTypeConflict describes a column that exists in both schemas but with incompatible types.
|
||||||
|
type ColumnTypeConflict struct {
|
||||||
|
Schema string
|
||||||
|
Table string
|
||||||
|
Column string
|
||||||
|
TargetType string
|
||||||
|
SourceType string
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeOptions contains options for merge operations
|
// MergeOptions contains options for merge operations
|
||||||
@@ -146,11 +156,19 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
|||||||
|
|
||||||
// Merge columns
|
// Merge columns
|
||||||
for colName, srcCol := range srcTable.Columns {
|
for colName, srcCol := range srcTable.Columns {
|
||||||
if _, exists := existingColumns[colName]; !exists {
|
if tgtCol, exists := existingColumns[colName]; !exists {
|
||||||
// Column doesn't exist, add it
|
// Column doesn't exist, add it
|
||||||
newCol := cloneColumn(srcCol)
|
newCol := cloneColumn(srcCol)
|
||||||
table.Columns[colName] = newCol
|
table.Columns[colName] = newCol
|
||||||
r.ColumnsAdded++
|
r.ColumnsAdded++
|
||||||
|
} else if columnTypeConflict(tgtCol, srcCol) {
|
||||||
|
r.TypeConflicts = append(r.TypeConflicts, ColumnTypeConflict{
|
||||||
|
Schema: firstNonEmpty(table.Schema, srcTable.Schema, srcCol.Schema),
|
||||||
|
Table: firstNonEmpty(table.Name, srcTable.Name, srcCol.Table),
|
||||||
|
Column: firstNonEmpty(tgtCol.Name, srcCol.Name, colName),
|
||||||
|
TargetType: describeColumnType(tgtCol),
|
||||||
|
SourceType: describeColumnType(srcCol),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -426,6 +444,52 @@ func cloneColumn(col *models.Column) *models.Column {
|
|||||||
return newCol
|
return newCol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func columnTypeConflict(target, source *models.Column) bool {
|
||||||
|
if target == nil || source == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalizeType(target.Type) != normalizeType(source.Type) ||
|
||||||
|
target.Length != source.Length ||
|
||||||
|
target.Precision != source.Precision ||
|
||||||
|
target.Scale != source.Scale
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeType(value string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
func describeColumnType(col *models.Column) string {
|
||||||
|
if col == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
typeName := strings.TrimSpace(col.Type)
|
||||||
|
if typeName == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case col.Precision > 0 && col.Scale > 0:
|
||||||
|
return fmt.Sprintf("%s(%d,%d)", typeName, col.Precision, col.Scale)
|
||||||
|
case col.Precision > 0:
|
||||||
|
return fmt.Sprintf("%s(%d)", typeName, col.Precision)
|
||||||
|
case col.Length > 0:
|
||||||
|
return fmt.Sprintf("%s(%d)", typeName, col.Length)
|
||||||
|
default:
|
||||||
|
return typeName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmpty(values ...string) string {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
|
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
|
||||||
if constraint == nil {
|
if constraint == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -609,6 +673,7 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||||
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
|
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
|
||||||
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
|
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
|
||||||
|
fmt.Sprintf("Type conflicts: %d", len(result.TypeConflicts)),
|
||||||
}
|
}
|
||||||
|
|
||||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||||
@@ -625,3 +690,35 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
|
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetColumnTypeConflictSummary returns a short, human-readable conflict summary.
|
||||||
|
func GetColumnTypeConflictSummary(result *MergeResult, limit int) string {
|
||||||
|
if result == nil || len(result.TypeConflicts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = len(result.TypeConflicts)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := make([]string, 0, min(limit, len(result.TypeConflicts))+1)
|
||||||
|
lines = append(lines, "column type conflicts detected:")
|
||||||
|
for i, conflict := range result.TypeConflicts {
|
||||||
|
if i >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf(" - %s.%s.%s: target=%s source=%s",
|
||||||
|
conflict.Schema, conflict.Table, conflict.Column, conflict.TargetType, conflict.SourceType))
|
||||||
|
}
|
||||||
|
if len(result.TypeConflicts) > limit {
|
||||||
|
lines = append(lines, fmt.Sprintf(" ... and %d more", len(result.TypeConflicts)-limit))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package merge
|
package merge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
@@ -140,6 +141,61 @@ func TestMergeColumns_NewColumn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMergeColumns_TypeConflictIsDetected(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "varchar", Length: 255},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
|
||||||
|
if len(result.TypeConflicts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 type conflict, got %d", len(result.TypeConflicts))
|
||||||
|
}
|
||||||
|
conflict := result.TypeConflicts[0]
|
||||||
|
if conflict.Schema != "public" || conflict.Table != "users" || conflict.Column != "email" {
|
||||||
|
t.Fatalf("Unexpected conflict location: %+v", conflict)
|
||||||
|
}
|
||||||
|
if conflict.TargetType != "varchar(255)" {
|
||||||
|
t.Fatalf("Expected target type varchar(255), got %q", conflict.TargetType)
|
||||||
|
}
|
||||||
|
if conflict.SourceType != "text" {
|
||||||
|
t.Fatalf("Expected source type text, got %q", conflict.SourceType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := target.Schemas[0].Tables[0].Columns["email"].Type; got != "varchar" {
|
||||||
|
t.Fatalf("Expected target column type to remain unchanged, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
||||||
target := &models.Database{
|
target := &models.Database{
|
||||||
Schemas: []*models.Schema{
|
Schemas: []*models.Schema{
|
||||||
@@ -509,6 +565,9 @@ func TestGetMergeSummary(t *testing.T) {
|
|||||||
ConstraintsAdded: 3,
|
ConstraintsAdded: 3,
|
||||||
IndexesAdded: 2,
|
IndexesAdded: 2,
|
||||||
ViewsAdded: 1,
|
ViewsAdded: 1,
|
||||||
|
TypeConflicts: []ColumnTypeConflict{
|
||||||
|
{Schema: "public", Table: "users", Column: "email", TargetType: "varchar(255)", SourceType: "text"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
summary := GetMergeSummary(result)
|
summary := GetMergeSummary(result)
|
||||||
@@ -518,6 +577,9 @@ func TestGetMergeSummary(t *testing.T) {
|
|||||||
if len(summary) < 50 {
|
if len(summary) < 50 {
|
||||||
t.Errorf("Summary seems too short: %s", summary)
|
t.Errorf("Summary seems too short: %s", summary)
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(summary, "Type conflicts: 1") {
|
||||||
|
t.Errorf("Expected type conflict count in summary, got: %s", summary)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetMergeSummary_Nil(t *testing.T) {
|
func TestGetMergeSummary_Nil(t *testing.T) {
|
||||||
|
|||||||
@@ -145,6 +145,24 @@ var postgresTypeAliases = map[string]string{
|
|||||||
"bool": "boolean",
|
"bool": "boolean",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var postgresEquivalentBaseTypes = map[string]string{
|
||||||
|
"character varying": "varchar",
|
||||||
|
"character": "char",
|
||||||
|
"timestamp without time zone": "timestamp",
|
||||||
|
"timestamp with time zone": "timestamptz",
|
||||||
|
"time without time zone": "time",
|
||||||
|
"time with time zone": "timetz",
|
||||||
|
}
|
||||||
|
|
||||||
|
var postgresEquivalentBaseTypeVariants = map[string][]string{
|
||||||
|
"varchar": {"varchar", "character varying"},
|
||||||
|
"char": {"char", "character"},
|
||||||
|
"timestamp": {"timestamp", "timestamp without time zone"},
|
||||||
|
"timestamptz": {"timestamptz", "timestamp with time zone"},
|
||||||
|
"time": {"time", "time without time zone"},
|
||||||
|
"timetz": {"timetz", "time with time zone"},
|
||||||
|
}
|
||||||
|
|
||||||
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
|
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
|
||||||
func GetPostgresBaseTypes() []string {
|
func GetPostgresBaseTypes() []string {
|
||||||
result := make([]string, 0, len(postgresBaseTypes))
|
result := make([]string, 0, len(postgresBaseTypes))
|
||||||
@@ -212,6 +230,86 @@ func CanonicalizeBaseType(baseType string) string {
|
|||||||
return base
|
return base
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EquivalentBaseType resolves broader SQL-equivalent spellings to a common comparable form.
|
||||||
|
func EquivalentBaseType(baseType string) string {
|
||||||
|
base := CanonicalizeBaseType(baseType)
|
||||||
|
if equivalent, ok := postgresEquivalentBaseTypes[base]; ok {
|
||||||
|
return equivalent
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeEquivalentSQLType returns a normalized SQL type string suitable for equality checks.
|
||||||
|
// Equivalent spellings such as "character varying(255)" and "varchar(255)" normalize identically.
|
||||||
|
func NormalizeEquivalentSQLType(sqlType string) string {
|
||||||
|
t := normalizeTypeToken(sqlType)
|
||||||
|
if t == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayDepth := 0
|
||||||
|
for strings.HasSuffix(t, "[]") {
|
||||||
|
arrayDepth++
|
||||||
|
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := ""
|
||||||
|
if idx := strings.Index(t, "("); idx >= 0 {
|
||||||
|
modifier = strings.TrimSpace(t[idx:])
|
||||||
|
t = strings.TrimSpace(t[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
base := EquivalentBaseType(t)
|
||||||
|
normalized := base + modifier
|
||||||
|
for i := 0; i < arrayDepth; i++ {
|
||||||
|
normalized += "[]"
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// EquivalentSQLTypeVariants returns equivalent PostgreSQL spellings for a SQL type.
|
||||||
|
// Examples:
|
||||||
|
// - varchar(255) -> ["varchar(255)", "character varying(255)"]
|
||||||
|
// - timestamptz -> ["timestamptz", "timestamp with time zone"]
|
||||||
|
func EquivalentSQLTypeVariants(sqlType string) []string {
|
||||||
|
t := normalizeTypeToken(sqlType)
|
||||||
|
if t == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayDepth := 0
|
||||||
|
for strings.HasSuffix(t, "[]") {
|
||||||
|
arrayDepth++
|
||||||
|
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := ""
|
||||||
|
if idx := strings.Index(t, "("); idx >= 0 {
|
||||||
|
modifier = strings.TrimSpace(t[idx:])
|
||||||
|
t = strings.TrimSpace(t[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
base := EquivalentBaseType(t)
|
||||||
|
bases := postgresEquivalentBaseTypeVariants[base]
|
||||||
|
if len(bases) == 0 {
|
||||||
|
bases = []string{base}
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool, len(bases))
|
||||||
|
result := make([]string, 0, len(bases))
|
||||||
|
for _, variantBase := range bases {
|
||||||
|
variant := variantBase + modifier
|
||||||
|
for i := 0; i < arrayDepth; i++ {
|
||||||
|
variant += "[]"
|
||||||
|
}
|
||||||
|
if !seen[variant] {
|
||||||
|
seen[variant] = true
|
||||||
|
result = append(result, variant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
|
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
|
||||||
func IsKnownPostgresType(sqlType string) bool {
|
func IsKnownPostgresType(sqlType string) bool {
|
||||||
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
|
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
|
||||||
|
|||||||
@@ -97,3 +97,51 @@ func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeEquivalentSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{input: "character varying(255)", want: "varchar(255)"},
|
||||||
|
{input: "varchar(255)", want: "varchar(255)"},
|
||||||
|
{input: "timestamp with time zone", want: "timestamptz"},
|
||||||
|
{input: "timestamptz", want: "timestamptz"},
|
||||||
|
{input: "time without time zone", want: "time"},
|
||||||
|
{input: "character varying(255)[]", want: "varchar(255)[]"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := NormalizeEquivalentSQLType(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("NormalizeEquivalentSQLType(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEquivalentSQLTypeVariants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{input: "character varying(255)", want: []string{"varchar(255)", "character varying(255)"}},
|
||||||
|
{input: "timestamptz", want: []string{"timestamptz", "timestamp with time zone"}},
|
||||||
|
{input: "text[]", want: []string{"text[]"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := EquivalentSQLTypeVariants(tt.input)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Fatalf("EquivalentSQLTypeVariants(%q) len = %d, want %d (%v)", tt.input, len(got), len(tt.want), got)
|
||||||
|
}
|
||||||
|
for i := range tt.want {
|
||||||
|
if got[i] != tt.want[i] {
|
||||||
|
t.Fatalf("EquivalentSQLTypeVariants(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
|||||||
column.AutoIncrement = true
|
column.AutoIncrement = true
|
||||||
column.Default = defaultVal
|
column.Default = defaultVal
|
||||||
} else {
|
} else {
|
||||||
column.Default = defaultVal
|
column.Default = normalizePostgresDefault(defaultVal)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -613,3 +613,30 @@ func (r *Reader) parseIndexDefinition(indexName, tableName, schema, indexDef str
|
|||||||
|
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizePostgresDefault converts a raw PostgreSQL column_default expression into the
|
||||||
|
// unquoted string value that the model convention expects. PostgreSQL stores string
|
||||||
|
// literal defaults as 'value' or 'value'::type (e.g. '{}'::text[]), while every other
|
||||||
|
// reader stores the bare value so the writer can re-quote it correctly.
|
||||||
|
func normalizePostgresDefault(defaultVal string) string {
|
||||||
|
if !strings.HasPrefix(defaultVal, "'") {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
// Decode the SQL string literal: skip the leading quote, unescape '' → ', stop at
|
||||||
|
// the first unescaped closing quote (any trailing ::cast is ignored).
|
||||||
|
rest := defaultVal[1:]
|
||||||
|
var buf strings.Builder
|
||||||
|
for i := 0; i < len(rest); i++ {
|
||||||
|
if rest[i] == '\'' {
|
||||||
|
if i+1 < len(rest) && rest[i+1] == '\'' {
|
||||||
|
buf.WriteByte('\'')
|
||||||
|
i++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteByte(rest[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ func (r *Reader) ReadTable() (*models.Table, error) {
|
|||||||
// parsePrisma parses Prisma schema content and returns a Database model
|
// parsePrisma parses Prisma schema content and returns a Database model
|
||||||
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
||||||
db := models.InitDatabase("database")
|
db := models.InitDatabase("database")
|
||||||
|
db.SourceFormat = "prisma"
|
||||||
|
|
||||||
if r.options.Metadata != nil {
|
if r.options.Metadata != nil {
|
||||||
if name, ok := r.options.Metadata["name"].(string); ok {
|
if name, ok := r.options.Metadata["name"].(string); ok {
|
||||||
@@ -139,7 +140,7 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
|||||||
case "datasource":
|
case "datasource":
|
||||||
r.parseDatasource(blockContent, db)
|
r.parseDatasource(blockContent, db)
|
||||||
case "generator":
|
case "generator":
|
||||||
// We don't need to do anything with generator blocks
|
r.parseGenerator(blockContent, db)
|
||||||
case "model":
|
case "model":
|
||||||
if currentTable != nil {
|
if currentTable != nil {
|
||||||
r.parseModelFields(blockContent, currentTable)
|
r.parseModelFields(blockContent, currentTable)
|
||||||
@@ -173,10 +174,34 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
|||||||
// Second pass: resolve relationships
|
// Second pass: resolve relationships
|
||||||
r.resolveRelationships(schema)
|
r.resolveRelationships(schema)
|
||||||
|
|
||||||
|
if db.SourceFormat == "prisma" && r.options != nil && r.options.Prisma7 {
|
||||||
|
db.SourceFormat = "prisma7"
|
||||||
|
}
|
||||||
|
|
||||||
db.Schemas = append(db.Schemas, schema)
|
db.Schemas = append(db.Schemas, schema)
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Reader) parseGenerator(lines []string, db *models.Database) {
|
||||||
|
providerRegex := regexp.MustCompile(`provider\s*=\s*"([^"]+)"`)
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
if matches := providerRegex.FindStringSubmatch(line); matches != nil {
|
||||||
|
switch matches[1] {
|
||||||
|
case "prisma-client":
|
||||||
|
db.SourceFormat = "prisma7"
|
||||||
|
default:
|
||||||
|
db.SourceFormat = "prisma"
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.options != nil && r.options.Prisma7 {
|
||||||
|
db.SourceFormat = "prisma7"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// parseDatasource extracts database type from datasource block
|
// parseDatasource extracts database type from datasource block
|
||||||
func (r *Reader) parseDatasource(lines []string, db *models.Database) {
|
func (r *Reader) parseDatasource(lines []string, db *models.Database) {
|
||||||
providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)
|
providerRegex := regexp.MustCompile(`provider\s*=\s*"?(\w+)"?`)
|
||||||
|
|||||||
77
pkg/readers/prisma/reader_test.go
Normal file
77
pkg/readers/prisma/reader_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package prisma
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReadDatabase_Prisma7GeneratorSetsSourceFormat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
schemaPath := filepath.Join(tmpDir, "schema.prisma")
|
||||||
|
|
||||||
|
content := `datasource db {
|
||||||
|
provider = "postgresql"
|
||||||
|
url = env("DATABASE_URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
generator client {
|
||||||
|
provider = "prisma-client"
|
||||||
|
output = "./generated"
|
||||||
|
}
|
||||||
|
|
||||||
|
model User {
|
||||||
|
id Int @id @default(autoincrement())
|
||||||
|
}`
|
||||||
|
|
||||||
|
if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(&readers.ReaderOptions{FilePath: schemaPath})
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.SourceFormat != "prisma7" {
|
||||||
|
t.Fatalf("expected SourceFormat prisma7, got %q", db.SourceFormat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabase_Prisma7FlagSetsSourceFormatWithoutGenerator(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
schemaPath := filepath.Join(tmpDir, "schema.prisma")
|
||||||
|
|
||||||
|
content := `datasource db {
|
||||||
|
provider = "postgresql"
|
||||||
|
url = env("DATABASE_URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
model User {
|
||||||
|
id Int @id @default(autoincrement())
|
||||||
|
}`
|
||||||
|
|
||||||
|
if err := os.WriteFile(schemaPath, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(&readers.ReaderOptions{
|
||||||
|
FilePath: schemaPath,
|
||||||
|
Prisma7: true,
|
||||||
|
})
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.SourceFormat != "prisma7" {
|
||||||
|
t.Fatalf("expected SourceFormat prisma7 from flag, got %q", db.SourceFormat)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,9 @@ type ReaderOptions struct {
|
|||||||
// ConnectionString is the database connection string (for DB readers)
|
// ConnectionString is the database connection string (for DB readers)
|
||||||
ConnectionString string
|
ConnectionString string
|
||||||
|
|
||||||
|
// Prisma7 enables Prisma 7-specific handling for Prisma schemas.
|
||||||
|
Prisma7 bool
|
||||||
|
|
||||||
// Additional options can be added here as needed
|
// Additional options can be added here as needed
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,54 +46,67 @@ func main() {
|
|||||||
### CLI Examples
|
### CLI Examples
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Generate Bun models from PostgreSQL database
|
# Generate Bun models from a DBML schema (default: resolvespec types)
|
||||||
relspec --input pgsql \
|
relspec convert --from dbml --from-path schema.dbml \
|
||||||
--conn "postgres://localhost/mydb" \
|
--to bun --to-path models.go --package models
|
||||||
--output bun \
|
|
||||||
--out-file models.go \
|
|
||||||
--package models
|
|
||||||
|
|
||||||
# Convert GORM models to Bun
|
# Use standard library database/sql nullable types instead of resolvespec
|
||||||
relspec --input gorm --in-file gorm_models.go --output bun --out-file bun_models.go
|
relspec convert --from dbml --from-path schema.dbml \
|
||||||
|
--to bun --to-path models.go --package models \
|
||||||
|
--types stdlib
|
||||||
|
|
||||||
# Multi-file output
|
# Explicitly select resolvespec types (same as omitting --types)
|
||||||
relspec --input json --in-file schema.json --output bun --out-file models/
|
relspec convert --from pgsql --from-conn "postgres://localhost/mydb" \
|
||||||
|
--to bun --to-path models.go --package models \
|
||||||
|
--types resolvespec
|
||||||
|
|
||||||
|
# Multi-file output (one file per table)
|
||||||
|
relspec convert --from json --from-path schema.json \
|
||||||
|
--to bun --to-path models/ --package models
|
||||||
```
|
```
|
||||||
|
|
||||||
## Generated Code Example
|
## Generated Code Examples
|
||||||
|
|
||||||
|
### Default — resolvespec types (`--types resolvespec`)
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
resolvespec_common "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
"database/sql"
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
bun.BaseModel `bun:"table:users,alias:u"`
|
bun.BaseModel `bun:"table:users,alias:u"`
|
||||||
|
|
||||||
ID int64 `bun:"id,pk,autoincrement" json:"id"`
|
ID int64 `bun:"id,type:uuid,pk," json:"id"`
|
||||||
Username string `bun:"username,notnull,unique" json:"username"`
|
Username string `bun:"username,type:text,notnull," json:"username"`
|
||||||
Email string `bun:"email,notnull" json:"email"`
|
Email resolvespec_common.SqlString `bun:"email,type:text,nullzero," json:"email"`
|
||||||
Bio sql.NullString `bun:"bio" json:"bio,omitempty"`
|
Tags resolvespec_common.SqlStringArray `bun:"tags,type:text[],default:'{}',notnull," json:"tags"`
|
||||||
CreatedAt time.Time `bun:"created_at,notnull,default:now()" json:"created_at"`
|
CreatedAt resolvespec_common.SqlTimeStamp `bun:"created_at,type:timestamptz,default:now(),notnull," json:"created_at"`
|
||||||
|
|
||||||
// Relationships
|
|
||||||
Posts []*Post `bun:"rel:has-many,join:id=user_id" json:"posts,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
```
|
||||||
|
|
||||||
type Post struct {
|
### Standard library — `--types stdlib`
|
||||||
bun.BaseModel `bun:"table:posts,alias:p"`
|
|
||||||
|
|
||||||
ID int64 `bun:"id,pk" json:"id"`
|
```go
|
||||||
UserID int64 `bun:"user_id,notnull" json:"user_id"`
|
package models
|
||||||
Title string `bun:"title,notnull" json:"title"`
|
|
||||||
Content sql.NullString `bun:"content" json:"content,omitempty"`
|
|
||||||
|
|
||||||
// Belongs to
|
import (
|
||||||
User *User `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
bun.BaseModel `bun:"table:users,alias:u"`
|
||||||
|
|
||||||
|
ID string `bun:"id,type:uuid,pk," json:"id"`
|
||||||
|
Username string `bun:"username,type:text,notnull," json:"username"`
|
||||||
|
Email sql.NullString `bun:"email,type:text,nullzero," json:"email"`
|
||||||
|
Tags []string `bun:"tags,type:text[],default:'{}',notnull," json:"tags"`
|
||||||
|
CreatedAt time.Time `bun:"created_at,type:timestamptz,default:now(),notnull," json:"created_at"`
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -111,19 +124,68 @@ type Post struct {
|
|||||||
|
|
||||||
## Type Mapping
|
## Type Mapping
|
||||||
|
|
||||||
| SQL Type | Go Type | Nullable Type |
|
The nullable type package is selected with `--types` (or `WriterOptions.NullableTypes`).
|
||||||
|----------|---------|---------------|
|
|
||||||
| bigint | int64 | sql.NullInt64 |
|
| SQL Type | NOT NULL (both) | Nullable — resolvespec | Nullable — stdlib |
|
||||||
| integer | int | sql.NullInt32 |
|
|---|---|---|---|
|
||||||
| varchar, text | string | sql.NullString |
|
| `bigint` | `int64` | `SqlInt64` | `sql.NullInt64` |
|
||||||
| boolean | bool | sql.NullBool |
|
| `integer` | `int32` | `SqlInt32` | `sql.NullInt32` |
|
||||||
| timestamp | time.Time | sql.NullTime |
|
| `smallint` | `int16` | `SqlInt16` | `sql.NullInt16` |
|
||||||
| numeric | float64 | sql.NullFloat64 |
|
| `text`, `varchar` | `string` | `SqlString` | `sql.NullString` |
|
||||||
|
| `boolean` | `bool` | `SqlBool` | `sql.NullBool` |
|
||||||
|
| `timestamp`, `timestamptz` | `time.Time`* | `SqlTimeStamp` | `sql.NullTime` |
|
||||||
|
| `numeric`, `decimal` | `float64` | `SqlFloat64` | `sql.NullFloat64` |
|
||||||
|
| `uuid` | `string` | `SqlUUID` | `sql.NullString` |
|
||||||
|
| `jsonb` | `string` | `SqlJSONB` | `sql.NullString` |
|
||||||
|
| `text[]` | `SqlStringArray` | `SqlStringArray` | `[]string` |
|
||||||
|
| `integer[]` | `SqlInt32Array` | `SqlInt32Array` | `[]int32` |
|
||||||
|
| `uuid[]` | `SqlUUIDArray` | `SqlUUIDArray` | `[]string` |
|
||||||
|
| `vector` | `SqlVector` | `SqlVector` | `[]float32` |
|
||||||
|
|
||||||
|
\* In resolvespec mode, NOT NULL timestamps use `SqlTimeStamp` (not `time.Time`) unless the base type is a simple integer or boolean. In stdlib mode, NOT NULL timestamps use `time.Time`.
|
||||||
|
|
||||||
|
## Writer Options
|
||||||
|
|
||||||
|
### NullableTypes
|
||||||
|
|
||||||
|
Controls which Go package is used for nullable column types. Set via the `--types` CLI flag or `WriterOptions.NullableTypes`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Use resolvespec types (default — omit NullableTypes or set to "resolvespec")
|
||||||
|
options := &writers.WriterOptions{
|
||||||
|
OutputPath: "models.go",
|
||||||
|
PackageName: "models",
|
||||||
|
NullableTypes: writers.NullableTypeResolveSpec,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use standard library database/sql types
|
||||||
|
options := &writers.WriterOptions{
|
||||||
|
OutputPath: "models.go",
|
||||||
|
PackageName: "models",
|
||||||
|
NullableTypes: writers.NullableTypeStdlib,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Metadata Options
|
||||||
|
|
||||||
|
```go
|
||||||
|
options := &writers.WriterOptions{
|
||||||
|
OutputPath: "models.go",
|
||||||
|
PackageName: "models",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"multi_file": true, // Enable multi-file mode
|
||||||
|
"populate_refs": true, // Populate RefDatabase/RefSchema
|
||||||
|
"generate_get_id_str": true, // Generate GetIDStr() methods
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- Model names are derived from table names (singularized, PascalCase)
|
- Model names are derived from table names (singularized, PascalCase)
|
||||||
- Table aliases are auto-generated from table names
|
- Table aliases are auto-generated from table names
|
||||||
|
- Nullable columns use `resolvespec_common.SqlString`, `resolvespec_common.SqlTimeStamp`, etc. by default; pass `--types stdlib` to use `sql.NullString`, `sql.NullTime`, etc. instead
|
||||||
|
- Array columns use `resolvespec_common.SqlStringArray`, `resolvespec_common.SqlInt32Array`, etc. by default; `--types stdlib` produces plain Go slices (`[]string`, `[]int32`, …)
|
||||||
- Multi-file mode: one file per table named `sql_{schema}_{table}.go`
|
- Multi-file mode: one file per table named `sql_{schema}_{table}.go`
|
||||||
- Generated code is auto-formatted
|
- Generated code is auto-formatted
|
||||||
- JSON tags are automatically added
|
- JSON tags are automatically added
|
||||||
|
|||||||
@@ -26,7 +26,10 @@ type ModelData struct {
|
|||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
|
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
|
||||||
|
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
|
||||||
|
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
|
||||||
IDColumnName string // Name of the ID column in database
|
IDColumnName string // Name of the ID column in database
|
||||||
Prefix string // 3-letter prefix
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
|
|||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
|
model.PrimaryKeyType = goType
|
||||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
||||||
|
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
|
||||||
|
model.PrimaryKeyIDType = "int64"
|
||||||
|
if model.PrimaryKeyIsStr {
|
||||||
|
model.PrimaryKeyIDType = "string"
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStringLikePrimaryKeyType(goType string) bool {
|
||||||
|
switch goType {
|
||||||
|
case "string", "sql.NullString", "resolvespec_common.SqlString", "resolvespec_common.SqlUUID":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
// and adds an underscore suffix if there's a collision
|
// and adds an underscore suffix if there's a collision
|
||||||
func resolveFieldNameCollision(fieldName string) string {
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
|||||||
@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
|
|||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||||
// GetID returns the primary key value
|
// GetID returns the primary key value
|
||||||
func (m {{.Name}}) GetID() int64 {
|
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
|
||||||
{{if .PrimaryKeyIsSQL -}}
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else -}}
|
||||||
return m.{{.PrimaryKeyField}}.Int64()
|
return m.{{.PrimaryKeyField}}.Int64()
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
return int64(m.{{.PrimaryKeyField}})
|
return int64(m.{{.PrimaryKeyField}})
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||||
// GetIDStr returns the primary key as a string
|
// GetIDStr returns the primary key as a string
|
||||||
func (m {{.Name}}) GetIDStr() string {
|
func (m {{.Name}}) GetIDStr() string {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
|
{{- else -}}
|
||||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||||
// SetID sets the primary key value
|
// SetID sets the primary key value
|
||||||
func (m {{.Name}}) SetID(newid int64) {
|
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
|
||||||
m.UpdateID(newid)
|
m.UpdateID(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
|
||||||
{{if .PrimaryKeyIsSQL -}}
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}}.FromString(newid)
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
m.{{.PrimaryKeyField}} = int32(newid)
|
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}} = newid
|
||||||
|
{{- else -}}
|
||||||
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
|
{{- end}}
|
||||||
{{- end}}
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
|
|||||||
@@ -11,30 +11,43 @@ import (
|
|||||||
|
|
||||||
// TypeMapper handles type conversions between SQL and Go types for Bun
|
// TypeMapper handles type conversions between SQL and Go types for Bun
|
||||||
type TypeMapper struct {
|
type TypeMapper struct {
|
||||||
// Package alias for sql_types import
|
|
||||||
sqlTypesAlias string
|
sqlTypesAlias string
|
||||||
|
typeStyle string // writers.NullableTypeResolveSpec | writers.NullableTypeStdlib
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTypeMapper creates a new TypeMapper with default settings
|
// NewTypeMapper creates a new TypeMapper.
|
||||||
func NewTypeMapper() *TypeMapper {
|
// typeStyle should be writers.NullableTypeResolveSpec or writers.NullableTypeStdlib;
|
||||||
|
// an empty string defaults to resolvespec.
|
||||||
|
func NewTypeMapper(typeStyle string) *TypeMapper {
|
||||||
|
if typeStyle == "" {
|
||||||
|
typeStyle = writers.NullableTypeResolveSpec
|
||||||
|
}
|
||||||
return &TypeMapper{
|
return &TypeMapper{
|
||||||
sqlTypesAlias: "resolvespec_common",
|
sqlTypesAlias: "resolvespec_common",
|
||||||
|
typeStyle: typeStyle,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLTypeToGoType converts a SQL type to its Go equivalent
|
// SQLTypeToGoType converts a SQL type to its Go equivalent.
|
||||||
// Uses ResolveSpec common package types (all are nullable by default in Bun)
|
|
||||||
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
||||||
// Normalize SQL type (lowercase, remove length/precision)
|
// Array types are handled separately for both styles.
|
||||||
|
if pgsql.IsArrayType(sqlType) {
|
||||||
|
return tm.arrayGoType(tm.extractBaseType(sqlType))
|
||||||
|
}
|
||||||
|
|
||||||
baseType := tm.extractBaseType(sqlType)
|
baseType := tm.extractBaseType(sqlType)
|
||||||
|
|
||||||
// For Bun, we typically use resolvespec_common types for most fields
|
if tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
// unless they're explicitly NOT NULL and we want to avoid null handling
|
if notNull {
|
||||||
|
return tm.rawGoType(baseType)
|
||||||
|
}
|
||||||
|
return tm.stdlibNullableGoType(baseType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvespec (default): use base Go types only for simple NOT NULL fields.
|
||||||
if notNull && tm.isSimpleType(baseType) {
|
if notNull && tm.isSimpleType(baseType) {
|
||||||
return tm.baseGoType(baseType)
|
return tm.baseGoType(baseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use resolvespec_common types for nullable fields
|
|
||||||
return tm.bunGoType(baseType)
|
return tm.bunGoType(baseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,6 +167,9 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
|
|||||||
|
|
||||||
// Other
|
// Other
|
||||||
"money": tm.sqlTypesAlias + ".SqlFloat64",
|
"money": tm.sqlTypesAlias + ".SqlFloat64",
|
||||||
|
|
||||||
|
// pgvector
|
||||||
|
"vector": tm.sqlTypesAlias + ".SqlVector",
|
||||||
}
|
}
|
||||||
|
|
||||||
if goType, ok := typeMap[sqlType]; ok {
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
@@ -164,6 +180,123 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
|
|||||||
return tm.sqlTypesAlias + ".SqlString"
|
return tm.sqlTypesAlias + ".SqlString"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// arrayGoType returns the Go type for a PostgreSQL array column.
|
||||||
|
// The baseElemType is the canonical base type (e.g. "text", "integer").
|
||||||
|
func (tm *TypeMapper) arrayGoType(baseElemType string) string {
|
||||||
|
if tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
|
return tm.stdlibArrayGoType(baseElemType)
|
||||||
|
}
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"text": tm.sqlTypesAlias + ".SqlStringArray", "varchar": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"char": tm.sqlTypesAlias + ".SqlStringArray", "character": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"citext": tm.sqlTypesAlias + ".SqlStringArray", "bpchar": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"inet": tm.sqlTypesAlias + ".SqlStringArray", "cidr": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"macaddr": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"json": tm.sqlTypesAlias + ".SqlStringArray", "jsonb": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"integer": tm.sqlTypesAlias + ".SqlInt32Array", "int": tm.sqlTypesAlias + ".SqlInt32Array",
|
||||||
|
"int4": tm.sqlTypesAlias + ".SqlInt32Array", "serial": tm.sqlTypesAlias + ".SqlInt32Array",
|
||||||
|
"smallint": tm.sqlTypesAlias + ".SqlInt16Array", "int2": tm.sqlTypesAlias + ".SqlInt16Array",
|
||||||
|
"smallserial": tm.sqlTypesAlias + ".SqlInt16Array",
|
||||||
|
"bigint": tm.sqlTypesAlias + ".SqlInt64Array", "int8": tm.sqlTypesAlias + ".SqlInt64Array",
|
||||||
|
"bigserial": tm.sqlTypesAlias + ".SqlInt64Array",
|
||||||
|
"real": tm.sqlTypesAlias + ".SqlFloat32Array", "float4": tm.sqlTypesAlias + ".SqlFloat32Array",
|
||||||
|
"double precision": tm.sqlTypesAlias + ".SqlFloat64Array", "float8": tm.sqlTypesAlias + ".SqlFloat64Array",
|
||||||
|
"numeric": tm.sqlTypesAlias + ".SqlFloat64Array", "decimal": tm.sqlTypesAlias + ".SqlFloat64Array",
|
||||||
|
"money": tm.sqlTypesAlias + ".SqlFloat64Array",
|
||||||
|
"boolean": tm.sqlTypesAlias + ".SqlBoolArray", "bool": tm.sqlTypesAlias + ".SqlBoolArray",
|
||||||
|
"uuid": tm.sqlTypesAlias + ".SqlUUIDArray",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[baseElemType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return tm.sqlTypesAlias + ".SqlStringArray"
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawGoType returns the plain Go type for a NOT NULL column in stdlib mode.
|
||||||
|
func (tm *TypeMapper) rawGoType(sqlType string) string {
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"integer": "int32", "int": "int32", "int4": "int32", "serial": "int32",
|
||||||
|
"smallint": "int16", "int2": "int16", "smallserial": "int16",
|
||||||
|
"bigint": "int64", "int8": "int64", "bigserial": "int64",
|
||||||
|
"boolean": "bool", "bool": "bool",
|
||||||
|
"real": "float32", "float4": "float32",
|
||||||
|
"double precision": "float64", "float8": "float64",
|
||||||
|
"numeric": "float64", "decimal": "float64", "money": "float64",
|
||||||
|
"text": "string", "varchar": "string", "char": "string",
|
||||||
|
"character": "string", "citext": "string", "bpchar": "string",
|
||||||
|
"inet": "string", "cidr": "string", "macaddr": "string",
|
||||||
|
"uuid": "string", "json": "string", "jsonb": "string",
|
||||||
|
"timestamp": "time.Time",
|
||||||
|
"timestamp without time zone": "time.Time",
|
||||||
|
"timestamp with time zone": "time.Time",
|
||||||
|
"timestamptz": "time.Time",
|
||||||
|
"date": "time.Time",
|
||||||
|
"time": "time.Time",
|
||||||
|
"time without time zone": "time.Time",
|
||||||
|
"time with time zone": "time.Time",
|
||||||
|
"timetz": "time.Time",
|
||||||
|
"bytea": "[]byte",
|
||||||
|
"vector": "[]float32",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return "string"
|
||||||
|
}
|
||||||
|
|
||||||
|
// stdlibNullableGoType returns the database/sql nullable type for a column.
|
||||||
|
func (tm *TypeMapper) stdlibNullableGoType(sqlType string) string {
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"integer": "sql.NullInt32", "int": "sql.NullInt32", "int4": "sql.NullInt32", "serial": "sql.NullInt32",
|
||||||
|
"smallint": "sql.NullInt16", "int2": "sql.NullInt16", "smallserial": "sql.NullInt16",
|
||||||
|
"bigint": "sql.NullInt64", "int8": "sql.NullInt64", "bigserial": "sql.NullInt64",
|
||||||
|
"boolean": "sql.NullBool", "bool": "sql.NullBool",
|
||||||
|
"real": "sql.NullFloat64", "float4": "sql.NullFloat64",
|
||||||
|
"double precision": "sql.NullFloat64", "float8": "sql.NullFloat64",
|
||||||
|
"numeric": "sql.NullFloat64", "decimal": "sql.NullFloat64", "money": "sql.NullFloat64",
|
||||||
|
"text": "sql.NullString", "varchar": "sql.NullString", "char": "sql.NullString",
|
||||||
|
"character": "sql.NullString", "citext": "sql.NullString", "bpchar": "sql.NullString",
|
||||||
|
"inet": "sql.NullString", "cidr": "sql.NullString", "macaddr": "sql.NullString",
|
||||||
|
"uuid": "sql.NullString", "json": "sql.NullString", "jsonb": "sql.NullString",
|
||||||
|
"timestamp": "sql.NullTime",
|
||||||
|
"timestamp without time zone": "sql.NullTime",
|
||||||
|
"timestamp with time zone": "sql.NullTime",
|
||||||
|
"timestamptz": "sql.NullTime",
|
||||||
|
"date": "sql.NullTime",
|
||||||
|
"time": "sql.NullTime",
|
||||||
|
"time without time zone": "sql.NullTime",
|
||||||
|
"time with time zone": "sql.NullTime",
|
||||||
|
"timetz": "sql.NullTime",
|
||||||
|
"bytea": "[]byte",
|
||||||
|
"vector": "[]float32",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return "sql.NullString"
|
||||||
|
}
|
||||||
|
|
||||||
|
// stdlibArrayGoType returns a plain Go slice type for array columns in stdlib mode.
|
||||||
|
func (tm *TypeMapper) stdlibArrayGoType(baseElemType string) string {
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"text": "[]string", "varchar": "[]string", "char": "[]string",
|
||||||
|
"character": "[]string", "citext": "[]string", "bpchar": "[]string",
|
||||||
|
"inet": "[]string", "cidr": "[]string", "macaddr": "[]string",
|
||||||
|
"uuid": "[]string", "json": "[]string", "jsonb": "[]string",
|
||||||
|
"integer": "[]int32", "int": "[]int32", "int4": "[]int32", "serial": "[]int32",
|
||||||
|
"smallint": "[]int16", "int2": "[]int16", "smallserial": "[]int16",
|
||||||
|
"bigint": "[]int64", "int8": "[]int64", "bigserial": "[]int64",
|
||||||
|
"real": "[]float32", "float4": "[]float32",
|
||||||
|
"double precision": "[]float64", "float8": "[]float64",
|
||||||
|
"numeric": "[]float64", "decimal": "[]float64", "money": "[]float64",
|
||||||
|
"boolean": "[]bool", "bool": "[]bool",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[baseElemType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return "[]string"
|
||||||
|
}
|
||||||
|
|
||||||
// BuildBunTag generates a complete Bun tag string for a column
|
// BuildBunTag generates a complete Bun tag string for a column
|
||||||
// Bun format: bun:"column_name,type:type_name,pk,default:value"
|
// Bun format: bun:"column_name,type:type_name,pk,default:value"
|
||||||
func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) string {
|
func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) string {
|
||||||
@@ -178,10 +311,11 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
if column.Type != "" {
|
if column.Type != "" {
|
||||||
// Sanitize type to remove backticks
|
// Sanitize type to remove backticks
|
||||||
typeStr := writers.SanitizeStructTagValue(column.Type)
|
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||||
|
isArray := pgsql.IsArrayType(typeStr)
|
||||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
|
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
|
||||||
if !hasExplicitTypeModifier && column.Length > 0 {
|
if !hasExplicitTypeModifier && !isArray && column.Length > 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||||
} else if !hasExplicitTypeModifier && column.Precision > 0 {
|
} else if !hasExplicitTypeModifier && !isArray && column.Precision > 0 {
|
||||||
if column.Scale > 0 {
|
if column.Scale > 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
|
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
|
||||||
} else {
|
} else {
|
||||||
@@ -189,6 +323,9 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
|
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
|
||||||
|
if isArray && tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
|
parts = append(parts, "array")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Primary key
|
// Primary key
|
||||||
@@ -286,11 +423,20 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
|||||||
return generateGetIDStr
|
return generateGetIDStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
|
// GetSQLTypesImport returns the import path for the ResolveSpec spectypes package.
|
||||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||||
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNullableTypeImportLine returns the full Go import line for the nullable type
|
||||||
|
// package (ready to pass to AddImport). Returns empty string when no import is needed.
|
||||||
|
func (tm *TypeMapper) GetNullableTypeImportLine() string {
|
||||||
|
if tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
|
return "\"database/sql\""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s \"%s\"", tm.sqlTypesAlias, tm.GetSQLTypesImport())
|
||||||
|
}
|
||||||
|
|
||||||
// GetBunImport returns the import path for Bun
|
// GetBunImport returns the import path for Bun
|
||||||
func (tm *TypeMapper) GetBunImport() string {
|
func (tm *TypeMapper) GetBunImport() string {
|
||||||
return "github.com/uptrace/bun"
|
return "github.com/uptrace/bun"
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type Writer struct {
|
|||||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||||
w := &Writer{
|
w := &Writer{
|
||||||
options: options,
|
options: options,
|
||||||
typeMapper: NewTypeMapper(),
|
typeMapper: NewTypeMapper(options.NullableTypes),
|
||||||
config: LoadMethodConfigFromMetadata(options.Metadata),
|
config: LoadMethodConfigFromMetadata(options.Metadata),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,8 +80,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
// Add bun import (always needed)
|
// Add bun import (always needed)
|
||||||
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
||||||
|
|
||||||
// Add resolvespec_common import (always needed for nullable types)
|
// Add nullable types import (resolvespec or stdlib depending on options)
|
||||||
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
|
||||||
|
|
||||||
// Collect all models
|
// Collect all models
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,8 +177,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
// Add bun import
|
// Add bun import
|
||||||
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
||||||
|
|
||||||
// Add resolvespec_common import
|
// Add nullable types import (resolvespec or stdlib depending on options)
|
||||||
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
|
||||||
|
|
||||||
// Create model data
|
// Create model data
|
||||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||||
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) needsFmtImport(models []*ModelData) bool {
|
||||||
|
if w.config.GenerateGetIDStr {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.config.GenerateUpdateID {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// findTable finds a table by schema and name in the database
|
// findTable finds a table by schema and name in the database
|
||||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
|
|||||||
@@ -556,7 +556,7 @@ func TestWriter_FieldNameCollision(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
sqlType string
|
sqlType string
|
||||||
@@ -574,6 +574,10 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
|||||||
{"boolean", false, "resolvespec_common.SqlBool"},
|
{"boolean", false, "resolvespec_common.SqlBool"},
|
||||||
{"uuid", false, "resolvespec_common.SqlUUID"},
|
{"uuid", false, "resolvespec_common.SqlUUID"},
|
||||||
{"jsonb", false, "resolvespec_common.SqlJSONB"},
|
{"jsonb", false, "resolvespec_common.SqlJSONB"},
|
||||||
|
{"text[]", true, "resolvespec_common.SqlStringArray"},
|
||||||
|
{"text[]", false, "resolvespec_common.SqlStringArray"},
|
||||||
|
{"integer[]", true, "resolvespec_common.SqlInt32Array"},
|
||||||
|
{"bigint[]", false, "resolvespec_common.SqlInt64Array"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -586,8 +590,118 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkType string
|
||||||
|
expectedPK string
|
||||||
|
expectedLine string
|
||||||
|
forbidInt32 bool
|
||||||
|
}{
|
||||||
|
{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
|
||||||
|
{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
|
||||||
|
{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table := models.InitTable("test_table", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: tt.pkType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
if !strings.Contains(generated, tt.expectedLine) {
|
||||||
|
t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(generated, "ID "+tt.expectedPK) {
|
||||||
|
t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
|
||||||
|
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
|
||||||
|
table := models.InitTable("accounts", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "uuid",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
expectations := []string{
|
||||||
|
"resolvespec_common.SqlUUID",
|
||||||
|
"func (m ModelPublicAccounts) GetID() string",
|
||||||
|
"return m.ID.String()",
|
||||||
|
"func (m ModelPublicAccounts) GetIDStr() string",
|
||||||
|
"func (m ModelPublicAccounts) SetID(newid string)",
|
||||||
|
"func (m *ModelPublicAccounts) UpdateID(newid string)",
|
||||||
|
"m.ID.FromString(newid)",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectations {
|
||||||
|
if !strings.Contains(generated, expected) {
|
||||||
|
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeMapper_BuildBunTag(t *testing.T) {
|
func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -685,6 +799,24 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
|||||||
},
|
},
|
||||||
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
|
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "text array type",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "tags",
|
||||||
|
Type: "text[]",
|
||||||
|
NotNull: false,
|
||||||
|
},
|
||||||
|
want: []string{"tags,", "type:text[],"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "integer array type",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "scores",
|
||||||
|
Type: "integer[]",
|
||||||
|
NotNull: true,
|
||||||
|
},
|
||||||
|
want: []string{"scores,", "type:integer[],"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -695,12 +827,36 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
|||||||
t.Errorf("BuildBunTag() = %q, missing %q", result, part)
|
t.Errorf("BuildBunTag() = %q, missing %q", result, part)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// resolvespec mode must NOT add "array" — SqlXxxArray uses sql.Scanner
|
||||||
|
if strings.Contains(result, ",array,") || strings.HasSuffix(result, ",array,") {
|
||||||
|
t.Errorf("BuildBunTag() = %q, must not contain 'array' in resolvespec mode", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeMapper_BuildBunTag_StdlibArrayHasArrayTag(t *testing.T) {
|
||||||
|
mapper := NewTypeMapper(writers.NullableTypeStdlib)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
column *models.Column
|
||||||
|
}{
|
||||||
|
{name: "text array", column: &models.Column{Name: "tags", Type: "text[]"}},
|
||||||
|
{name: "integer array", column: &models.Column{Name: "scores", Type: "integer[]", NotNull: true}},
|
||||||
|
}
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := mapper.BuildBunTag(tt.column, nil)
|
||||||
|
if !strings.Contains(result, "array") {
|
||||||
|
t.Errorf("BuildBunTag() = %q, expected 'array' in stdlib mode", result)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTypeMapper_BuildBunTag_PreservesExplicitTypeModifiers(t *testing.T) {
|
func TestTypeMapper_BuildBunTag_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
col := &models.Column{
|
col := &models.Column{
|
||||||
Name: "embedding",
|
Name: "embedding",
|
||||||
|
|||||||
@@ -48,22 +48,23 @@ func main() {
|
|||||||
### CLI Examples
|
### CLI Examples
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Generate GORM models from PostgreSQL database (single file)
|
# Generate GORM models from a DBML schema (default: resolvespec types)
|
||||||
relspec --input pgsql \
|
relspec convert --from dbml --from-path schema.dbml \
|
||||||
--conn "postgres://localhost/mydb" \
|
--to gorm --to-path models.go --package models
|
||||||
--output gorm \
|
|
||||||
--out-file models.go \
|
|
||||||
--package models
|
|
||||||
|
|
||||||
# Generate GORM models with multi-file output (one file per table)
|
# Use standard library database/sql nullable types instead of resolvespec
|
||||||
relspec --input json \
|
relspec convert --from dbml --from-path schema.dbml \
|
||||||
--in-file schema.json \
|
--to gorm --to-path models.go --package models \
|
||||||
--output gorm \
|
--types stdlib
|
||||||
--out-file models/ \
|
|
||||||
--package models
|
|
||||||
|
|
||||||
# Convert DBML to GORM models
|
# Explicitly select resolvespec types (same as omitting --types)
|
||||||
relspec --input dbml --in-file schema.dbml --output gorm --out-file models.go
|
relspec convert --from pgsql --from-conn "postgres://localhost/mydb" \
|
||||||
|
--to gorm --to-path models.go --package models \
|
||||||
|
--types resolvespec
|
||||||
|
|
||||||
|
# Multi-file output (one file per table)
|
||||||
|
relspec convert --from json --from-path schema.json \
|
||||||
|
--to gorm --to-path models/ --package models
|
||||||
```
|
```
|
||||||
|
|
||||||
## Output Modes
|
## Output Modes
|
||||||
@@ -86,56 +87,84 @@ relspec --input pgsql --conn "..." --output gorm --out-file models/
|
|||||||
|
|
||||||
Files are named: `sql_{schema}_{table}.go`
|
Files are named: `sql_{schema}_{table}.go`
|
||||||
|
|
||||||
## Generated Code Example
|
## Generated Code Examples
|
||||||
|
|
||||||
|
### Default — resolvespec types (`--types resolvespec`)
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
sql_types "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
sql_types "git.warky.dev/wdevs/sql_types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelUser struct {
|
type ModelUser struct {
|
||||||
ID int64 `gorm:"column:id;type:bigint;primaryKey;autoIncrement" json:"id"`
|
ID string `gorm:"column:id;type:uuid;primaryKey" json:"id"`
|
||||||
Username string `gorm:"column:username;type:varchar(50);not null;uniqueIndex" json:"username"`
|
Username string `gorm:"column:username;type:text;not null" json:"username"`
|
||||||
Email string `gorm:"column:email;type:varchar(100);not null" json:"email"`
|
Email sql_types.SqlString `gorm:"column:email;type:text" json:"email,omitempty"`
|
||||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:now()" json:"created_at"`
|
Tags sql_types.SqlStringArray `gorm:"column:tags;type:text[];not null;default:'{}'" json:"tags"`
|
||||||
|
CreatedAt sql_types.SqlTimeStamp `gorm:"column:created_at;type:timestamptz;not null;default:now()" json:"created_at"`
|
||||||
// Relationships
|
|
||||||
Pos []*ModelPost `gorm:"foreignKey:UserID;references:ID;constraint:OnDelete:CASCADE" json:"pos,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ModelUser) TableName() string {
|
func (ModelUser) TableName() string {
|
||||||
return "public.users"
|
return "public.users"
|
||||||
}
|
}
|
||||||
|
```
|
||||||
|
|
||||||
type ModelPost struct {
|
### Standard library — `--types stdlib`
|
||||||
ID int64 `gorm:"column:id;type:bigint;primaryKey" json:"id"`
|
|
||||||
UserID int64 `gorm:"column:user_id;type:bigint;not null" json:"user_id"`
|
|
||||||
Title string `gorm:"column:title;type:varchar(200);not null" json:"title"`
|
|
||||||
Content sql_types.SqlString `gorm:"column:content;type:text" json:"content,omitempty"`
|
|
||||||
|
|
||||||
// Belongs to
|
```go
|
||||||
Use *ModelUser `gorm:"foreignKey:UserID;references:ID" json:"use,omitempty"`
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelUser struct {
|
||||||
|
ID string `gorm:"column:id;type:uuid;primaryKey" json:"id"`
|
||||||
|
Username string `gorm:"column:username;type:text;not null" json:"username"`
|
||||||
|
Email sql.NullString `gorm:"column:email;type:text" json:"email,omitempty"`
|
||||||
|
Tags []string `gorm:"column:tags;type:text[];not null;default:'{}'" json:"tags"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;type:timestamptz;not null;default:now()" json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ModelPost) TableName() string {
|
func (ModelUser) TableName() string {
|
||||||
return "public.posts"
|
return "public.users"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Writer Options
|
## Writer Options
|
||||||
|
|
||||||
|
### NullableTypes
|
||||||
|
|
||||||
|
Controls which Go package is used for nullable column types. Set via the `--types` CLI flag or `WriterOptions.NullableTypes`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Use resolvespec types (default — omit NullableTypes or set to "resolvespec")
|
||||||
|
options := &writers.WriterOptions{
|
||||||
|
OutputPath: "models.go",
|
||||||
|
PackageName: "models",
|
||||||
|
NullableTypes: writers.NullableTypeResolveSpec,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use standard library database/sql types
|
||||||
|
options := &writers.WriterOptions{
|
||||||
|
OutputPath: "models.go",
|
||||||
|
PackageName: "models",
|
||||||
|
NullableTypes: writers.NullableTypeStdlib,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Metadata Options
|
### Metadata Options
|
||||||
|
|
||||||
Configure the writer behavior using metadata in `WriterOptions`:
|
Configure additional writer behavior using metadata in `WriterOptions`:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
options := &writers.WriterOptions{
|
options := &writers.WriterOptions{
|
||||||
OutputPath: "models.go",
|
OutputPath: "models.go",
|
||||||
PackageName: "models",
|
PackageName: "models",
|
||||||
Metadata: map[string]interface{}{
|
Metadata: map[string]any{
|
||||||
"multi_file": true, // Enable multi-file mode
|
"multi_file": true, // Enable multi-file mode
|
||||||
"populate_refs": true, // Populate RefDatabase/RefSchema
|
"populate_refs": true, // Populate RefDatabase/RefSchema
|
||||||
"generate_get_id_str": true, // Generate GetIDStr() methods
|
"generate_get_id_str": true, // Generate GetIDStr() methods
|
||||||
@@ -145,18 +174,23 @@ options := &writers.WriterOptions{
|
|||||||
|
|
||||||
## Type Mapping
|
## Type Mapping
|
||||||
|
|
||||||
| SQL Type | Go Type | Notes |
|
The nullable type package is selected with `--types` (or `WriterOptions.NullableTypes`).
|
||||||
|----------|---------|-------|
|
|
||||||
| bigint, int8 | int64 | - |
|
| SQL Type | NOT NULL — both | Nullable — resolvespec | Nullable — stdlib |
|
||||||
| integer, int, int4 | int | - |
|
|---|---|---|---|
|
||||||
| smallint, int2 | int16 | - |
|
| `bigint` | `int64` | `SqlInt64` | `sql.NullInt64` |
|
||||||
| varchar, text | string | Not nullable |
|
| `integer` | `int32` | `SqlInt32` | `sql.NullInt32` |
|
||||||
| varchar, text (nullable) | sql_types.SqlString | Nullable |
|
| `smallint` | `int16` | `SqlInt16` | `sql.NullInt16` |
|
||||||
| boolean, bool | bool | - |
|
| `text`, `varchar` | `string` | `SqlString` | `sql.NullString` |
|
||||||
| timestamp, timestamptz | time.Time | - |
|
| `boolean` | `bool` | `SqlBool` | `sql.NullBool` |
|
||||||
| numeric, decimal | float64 | - |
|
| `timestamp`, `timestamptz` | `time.Time` | `SqlTimeStamp` | `sql.NullTime` |
|
||||||
| uuid | string | - |
|
| `numeric`, `decimal` | `float64` | `SqlFloat64` | `sql.NullFloat64` |
|
||||||
| json, jsonb | string | - |
|
| `uuid` | `string` | `SqlUUID` | `sql.NullString` |
|
||||||
|
| `jsonb` | `string` | `SqlString` | `sql.NullString` |
|
||||||
|
| `text[]` | `SqlStringArray` | `SqlStringArray` | `[]string` |
|
||||||
|
| `integer[]` | `SqlInt32Array` | `SqlInt32Array` | `[]int32` |
|
||||||
|
| `uuid[]` | `SqlUUIDArray` | `SqlUUIDArray` | `[]string` |
|
||||||
|
| `vector` | `SqlVector` | `SqlVector` | `[]float32` |
|
||||||
|
|
||||||
## Relationship Generation
|
## Relationship Generation
|
||||||
|
|
||||||
@@ -170,7 +204,8 @@ The writer automatically generates relationship fields:
|
|||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- Model names are prefixed with "Model" (e.g., `ModelUser`)
|
- Model names are prefixed with "Model" (e.g., `ModelUser`)
|
||||||
- Nullable columns use `sql_types.SqlString`, `sql_types.SqlInt64`, etc.
|
- Nullable columns use `sql_types.SqlString`, `sql_types.SqlInt64`, etc. by default; pass `--types stdlib` to use `sql.NullString`, `sql.NullInt64`, etc. instead
|
||||||
|
- Array columns use `sql_types.SqlStringArray`, `sql_types.SqlInt32Array`, etc. by default; `--types stdlib` produces plain Go slices (`[]string`, `[]int32`, …)
|
||||||
- Generated code is auto-formatted with `go fmt`
|
- Generated code is auto-formatted with `go fmt`
|
||||||
- JSON tags are automatically added
|
- JSON tags are automatically added
|
||||||
- Supports schema-qualified table names in `TableName()` method
|
- Supports schema-qualified table names in `TableName()` method
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
@@ -26,6 +27,9 @@ type ModelData struct {
|
|||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
PrimaryKeyType string // Go type of the primary key field
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
|
PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
|
||||||
|
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
|
||||||
|
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
|
||||||
IDColumnName string // Name of the ID column in database
|
IDColumnName string // Name of the ID column in database
|
||||||
Prefix string // 3-letter prefix
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
|
|||||||
// Sanitize column name to remove backticks
|
// Sanitize column name to remove backticks
|
||||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
|
model.PrimaryKeyType = goType
|
||||||
|
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
|
||||||
|
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
|
||||||
|
model.PrimaryKeyIDType = "int64"
|
||||||
|
if model.PrimaryKeyIsStr {
|
||||||
|
model.PrimaryKeyIDType = "string"
|
||||||
|
}
|
||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStringLikePrimaryKeyType(goType string) bool {
|
||||||
|
switch goType {
|
||||||
|
case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
// and adds an underscore suffix if there's a collision
|
// and adds an underscore suffix if there's a collision
|
||||||
func resolveFieldNameCollision(fieldName string) string {
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
|||||||
@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
|
|||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||||
// GetID returns the primary key value
|
// GetID returns the primary key value
|
||||||
func (m {{.Name}}) GetID() int64 {
|
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.Int64()
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
|
{{- else -}}
|
||||||
return int64(m.{{.PrimaryKeyField}})
|
return int64(m.{{.PrimaryKeyField}})
|
||||||
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||||
// GetIDStr returns the primary key as a string
|
// GetIDStr returns the primary key as a string
|
||||||
func (m {{.Name}}) GetIDStr() string {
|
func (m {{.Name}}) GetIDStr() string {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
return m.{{.PrimaryKeyField}}.String()
|
||||||
|
{{- else if .PrimaryKeyIsStr -}}
|
||||||
|
return m.{{.PrimaryKeyField}}
|
||||||
|
{{- else -}}
|
||||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||||
// SetID sets the primary key value
|
// SetID sets the primary key value
|
||||||
func (m {{.Name}}) SetID(newid int64) {
|
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
|
||||||
m.UpdateID(newid)
|
m.UpdateID(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
|
||||||
|
{{if .PrimaryKeyIsSQL -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}}.FromString(newid)
|
||||||
|
{{- else -}}
|
||||||
|
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
|
||||||
|
{{- end}}
|
||||||
|
{{- else -}}
|
||||||
|
{{if .PrimaryKeyIsStr -}}
|
||||||
|
m.{{.PrimaryKeyField}} = newid
|
||||||
|
{{- else -}}
|
||||||
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||||
|
|||||||
@@ -11,29 +11,43 @@ import (
|
|||||||
|
|
||||||
// TypeMapper handles type conversions between SQL and Go types
|
// TypeMapper handles type conversions between SQL and Go types
|
||||||
type TypeMapper struct {
|
type TypeMapper struct {
|
||||||
// Package alias for sql_types import
|
|
||||||
sqlTypesAlias string
|
sqlTypesAlias string
|
||||||
|
typeStyle string // writers.NullableTypeResolveSpec | writers.NullableTypeStdlib
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTypeMapper creates a new TypeMapper with default settings
|
// NewTypeMapper creates a new TypeMapper.
|
||||||
func NewTypeMapper() *TypeMapper {
|
// typeStyle should be writers.NullableTypeResolveSpec or writers.NullableTypeStdlib;
|
||||||
|
// an empty string defaults to resolvespec.
|
||||||
|
func NewTypeMapper(typeStyle string) *TypeMapper {
|
||||||
|
if typeStyle == "" {
|
||||||
|
typeStyle = writers.NullableTypeResolveSpec
|
||||||
|
}
|
||||||
return &TypeMapper{
|
return &TypeMapper{
|
||||||
sqlTypesAlias: "sql_types",
|
sqlTypesAlias: "sql_types",
|
||||||
|
typeStyle: typeStyle,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLTypeToGoType converts a SQL type to its Go equivalent
|
// SQLTypeToGoType converts a SQL type to its Go equivalent.
|
||||||
// Handles nullable types using ResolveSpec sql_types package
|
|
||||||
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
||||||
// Normalize SQL type (lowercase, remove length/precision)
|
// Array types are handled separately for both styles.
|
||||||
|
if pgsql.IsArrayType(sqlType) {
|
||||||
|
return tm.arrayGoType(tm.extractBaseType(sqlType))
|
||||||
|
}
|
||||||
|
|
||||||
baseType := tm.extractBaseType(sqlType)
|
baseType := tm.extractBaseType(sqlType)
|
||||||
|
|
||||||
// If not null, use base Go types
|
if tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
|
if notNull {
|
||||||
|
return tm.rawGoType(baseType)
|
||||||
|
}
|
||||||
|
return tm.stdlibNullableGoType(baseType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvespec (default)
|
||||||
if notNull {
|
if notNull {
|
||||||
return tm.baseGoType(baseType)
|
return tm.baseGoType(baseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// For nullable fields, use sql_types
|
|
||||||
return tm.nullableGoType(baseType)
|
return tm.nullableGoType(baseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,6 +120,9 @@ func (tm *TypeMapper) baseGoType(sqlType string) string {
|
|||||||
|
|
||||||
// Other
|
// Other
|
||||||
"money": "float64",
|
"money": "float64",
|
||||||
|
|
||||||
|
// pgvector — always uses SqlVector even when NOT NULL
|
||||||
|
"vector": tm.sqlTypesAlias + ".SqlVector",
|
||||||
}
|
}
|
||||||
|
|
||||||
if goType, ok := typeMap[sqlType]; ok {
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
@@ -179,6 +196,9 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
|
|||||||
|
|
||||||
// Other
|
// Other
|
||||||
"money": tm.sqlTypesAlias + ".SqlFloat64",
|
"money": tm.sqlTypesAlias + ".SqlFloat64",
|
||||||
|
|
||||||
|
// pgvector
|
||||||
|
"vector": tm.sqlTypesAlias + ".SqlVector",
|
||||||
}
|
}
|
||||||
|
|
||||||
if goType, ok := typeMap[sqlType]; ok {
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
@@ -189,6 +209,123 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
|
|||||||
return tm.sqlTypesAlias + ".SqlString"
|
return tm.sqlTypesAlias + ".SqlString"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// arrayGoType returns the Go type for a PostgreSQL array column.
|
||||||
|
// The baseElemType is the canonical base type (e.g. "text", "integer").
|
||||||
|
func (tm *TypeMapper) arrayGoType(baseElemType string) string {
|
||||||
|
if tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
|
return tm.stdlibArrayGoType(baseElemType)
|
||||||
|
}
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"text": tm.sqlTypesAlias + ".SqlStringArray", "varchar": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"char": tm.sqlTypesAlias + ".SqlStringArray", "character": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"citext": tm.sqlTypesAlias + ".SqlStringArray", "bpchar": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"inet": tm.sqlTypesAlias + ".SqlStringArray", "cidr": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"macaddr": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"json": tm.sqlTypesAlias + ".SqlStringArray", "jsonb": tm.sqlTypesAlias + ".SqlStringArray",
|
||||||
|
"integer": tm.sqlTypesAlias + ".SqlInt32Array", "int": tm.sqlTypesAlias + ".SqlInt32Array",
|
||||||
|
"int4": tm.sqlTypesAlias + ".SqlInt32Array", "serial": tm.sqlTypesAlias + ".SqlInt32Array",
|
||||||
|
"smallint": tm.sqlTypesAlias + ".SqlInt16Array", "int2": tm.sqlTypesAlias + ".SqlInt16Array",
|
||||||
|
"smallserial": tm.sqlTypesAlias + ".SqlInt16Array",
|
||||||
|
"bigint": tm.sqlTypesAlias + ".SqlInt64Array", "int8": tm.sqlTypesAlias + ".SqlInt64Array",
|
||||||
|
"bigserial": tm.sqlTypesAlias + ".SqlInt64Array",
|
||||||
|
"real": tm.sqlTypesAlias + ".SqlFloat32Array", "float4": tm.sqlTypesAlias + ".SqlFloat32Array",
|
||||||
|
"double precision": tm.sqlTypesAlias + ".SqlFloat64Array", "float8": tm.sqlTypesAlias + ".SqlFloat64Array",
|
||||||
|
"numeric": tm.sqlTypesAlias + ".SqlFloat64Array", "decimal": tm.sqlTypesAlias + ".SqlFloat64Array",
|
||||||
|
"money": tm.sqlTypesAlias + ".SqlFloat64Array",
|
||||||
|
"boolean": tm.sqlTypesAlias + ".SqlBoolArray", "bool": tm.sqlTypesAlias + ".SqlBoolArray",
|
||||||
|
"uuid": tm.sqlTypesAlias + ".SqlUUIDArray",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[baseElemType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return tm.sqlTypesAlias + ".SqlStringArray"
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawGoType returns the plain Go type for a NOT NULL column in stdlib mode.
|
||||||
|
func (tm *TypeMapper) rawGoType(sqlType string) string {
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"integer": "int32", "int": "int32", "int4": "int32", "serial": "int32",
|
||||||
|
"smallint": "int16", "int2": "int16", "smallserial": "int16",
|
||||||
|
"bigint": "int64", "int8": "int64", "bigserial": "int64",
|
||||||
|
"boolean": "bool", "bool": "bool",
|
||||||
|
"real": "float32", "float4": "float32",
|
||||||
|
"double precision": "float64", "float8": "float64",
|
||||||
|
"numeric": "float64", "decimal": "float64", "money": "float64",
|
||||||
|
"text": "string", "varchar": "string", "char": "string",
|
||||||
|
"character": "string", "citext": "string", "bpchar": "string",
|
||||||
|
"inet": "string", "cidr": "string", "macaddr": "string",
|
||||||
|
"uuid": "string", "json": "string", "jsonb": "string",
|
||||||
|
"timestamp": "time.Time",
|
||||||
|
"timestamp without time zone": "time.Time",
|
||||||
|
"timestamp with time zone": "time.Time",
|
||||||
|
"timestamptz": "time.Time",
|
||||||
|
"date": "time.Time",
|
||||||
|
"time": "time.Time",
|
||||||
|
"time without time zone": "time.Time",
|
||||||
|
"time with time zone": "time.Time",
|
||||||
|
"timetz": "time.Time",
|
||||||
|
"bytea": "[]byte",
|
||||||
|
"vector": "[]float32",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return "string"
|
||||||
|
}
|
||||||
|
|
||||||
|
// stdlibNullableGoType returns the database/sql nullable type for a column.
|
||||||
|
func (tm *TypeMapper) stdlibNullableGoType(sqlType string) string {
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"integer": "sql.NullInt32", "int": "sql.NullInt32", "int4": "sql.NullInt32", "serial": "sql.NullInt32",
|
||||||
|
"smallint": "sql.NullInt16", "int2": "sql.NullInt16", "smallserial": "sql.NullInt16",
|
||||||
|
"bigint": "sql.NullInt64", "int8": "sql.NullInt64", "bigserial": "sql.NullInt64",
|
||||||
|
"boolean": "sql.NullBool", "bool": "sql.NullBool",
|
||||||
|
"real": "sql.NullFloat64", "float4": "sql.NullFloat64",
|
||||||
|
"double precision": "sql.NullFloat64", "float8": "sql.NullFloat64",
|
||||||
|
"numeric": "sql.NullFloat64", "decimal": "sql.NullFloat64", "money": "sql.NullFloat64",
|
||||||
|
"text": "sql.NullString", "varchar": "sql.NullString", "char": "sql.NullString",
|
||||||
|
"character": "sql.NullString", "citext": "sql.NullString", "bpchar": "sql.NullString",
|
||||||
|
"inet": "sql.NullString", "cidr": "sql.NullString", "macaddr": "sql.NullString",
|
||||||
|
"uuid": "sql.NullString", "json": "sql.NullString", "jsonb": "sql.NullString",
|
||||||
|
"timestamp": "sql.NullTime",
|
||||||
|
"timestamp without time zone": "sql.NullTime",
|
||||||
|
"timestamp with time zone": "sql.NullTime",
|
||||||
|
"timestamptz": "sql.NullTime",
|
||||||
|
"date": "sql.NullTime",
|
||||||
|
"time": "sql.NullTime",
|
||||||
|
"time without time zone": "sql.NullTime",
|
||||||
|
"time with time zone": "sql.NullTime",
|
||||||
|
"timetz": "sql.NullTime",
|
||||||
|
"bytea": "[]byte",
|
||||||
|
"vector": "[]float32",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[sqlType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return "sql.NullString"
|
||||||
|
}
|
||||||
|
|
||||||
|
// stdlibArrayGoType returns a plain Go slice type for array columns in stdlib mode.
|
||||||
|
func (tm *TypeMapper) stdlibArrayGoType(baseElemType string) string {
|
||||||
|
typeMap := map[string]string{
|
||||||
|
"text": "[]string", "varchar": "[]string", "char": "[]string",
|
||||||
|
"character": "[]string", "citext": "[]string", "bpchar": "[]string",
|
||||||
|
"inet": "[]string", "cidr": "[]string", "macaddr": "[]string",
|
||||||
|
"uuid": "[]string", "json": "[]string", "jsonb": "[]string",
|
||||||
|
"integer": "[]int32", "int": "[]int32", "int4": "[]int32", "serial": "[]int32",
|
||||||
|
"smallint": "[]int16", "int2": "[]int16", "smallserial": "[]int16",
|
||||||
|
"bigint": "[]int64", "int8": "[]int64", "bigserial": "[]int64",
|
||||||
|
"real": "[]float32", "float4": "[]float32",
|
||||||
|
"double precision": "[]float64", "float8": "[]float64",
|
||||||
|
"numeric": "[]float64", "decimal": "[]float64", "money": "[]float64",
|
||||||
|
"boolean": "[]bool", "bool": "[]bool",
|
||||||
|
}
|
||||||
|
if goType, ok := typeMap[baseElemType]; ok {
|
||||||
|
return goType
|
||||||
|
}
|
||||||
|
return "[]string"
|
||||||
|
}
|
||||||
|
|
||||||
// BuildGormTag generates a complete GORM tag string for a column
|
// BuildGormTag generates a complete GORM tag string for a column
|
||||||
func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) string {
|
func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) string {
|
||||||
var parts []string
|
var parts []string
|
||||||
@@ -330,7 +467,16 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
|||||||
return generateGetIDStr
|
return generateGetIDStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSQLTypesImport returns the import path for sql_types
|
// GetSQLTypesImport returns the import path for the ResolveSpec spectypes package.
|
||||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||||
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNullableTypeImportLine returns the full Go import line for the nullable type
|
||||||
|
// package (ready to pass to AddImport). Returns empty string when no import is needed.
|
||||||
|
func (tm *TypeMapper) GetNullableTypeImportLine() string {
|
||||||
|
if tm.typeStyle == writers.NullableTypeStdlib {
|
||||||
|
return "\"database/sql\""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s \"%s\"", tm.sqlTypesAlias, tm.GetSQLTypesImport())
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type Writer struct {
|
|||||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||||
w := &Writer{
|
w := &Writer{
|
||||||
options: options,
|
options: options,
|
||||||
typeMapper: NewTypeMapper(),
|
typeMapper: NewTypeMapper(options.NullableTypes),
|
||||||
config: LoadMethodConfigFromMetadata(options.Metadata),
|
config: LoadMethodConfigFromMetadata(options.Metadata),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,8 +77,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
packageName := w.getPackageName()
|
packageName := w.getPackageName()
|
||||||
templateData := NewTemplateData(packageName, w.config)
|
templateData := NewTemplateData(packageName, w.config)
|
||||||
|
|
||||||
// Add sql_types import (always needed for nullable types)
|
// Add nullable types import (resolvespec or stdlib depending on options)
|
||||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
|
||||||
|
|
||||||
// Collect all models
|
// Collect all models
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,8 +171,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
// Create template data for this single table
|
// Create template data for this single table
|
||||||
templateData := NewTemplateData(packageName, w.config)
|
templateData := NewTemplateData(packageName, w.config)
|
||||||
|
|
||||||
// Add sql_types import
|
// Add nullable types import (resolvespec or stdlib depending on options)
|
||||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
templateData.AddImport(w.typeMapper.GetNullableTypeImportLine())
|
||||||
|
|
||||||
// Create model data
|
// Create model data
|
||||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||||
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add fmt import if GetIDStr is enabled
|
// Add fmt import when generated helper methods need string formatting.
|
||||||
if w.config.GenerateGetIDStr {
|
if w.needsFmtImport(templateData.Models) {
|
||||||
templateData.AddImport("\"fmt\"")
|
templateData.AddImport("\"fmt\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) needsFmtImport(models []*ModelData) bool {
|
||||||
|
if w.config.GenerateGetIDStr {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.config.GenerateUpdateID {
|
||||||
|
for _, model := range models {
|
||||||
|
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// findTable finds a table by schema and name in the database
|
// findTable finds a table by schema and name in the database
|
||||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
|
|||||||
@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
|
||||||
|
table := models.InitTable("accounts", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "uuid",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
expectations := []string{
|
||||||
|
"ID string",
|
||||||
|
"func (m ModelPublicAccounts) GetID() string",
|
||||||
|
"return m.ID",
|
||||||
|
"func (m ModelPublicAccounts) GetIDStr() string",
|
||||||
|
"func (m ModelPublicAccounts) SetID(newid string)",
|
||||||
|
"func (m *ModelPublicAccounts) UpdateID(newid string)",
|
||||||
|
"m.ID = newid",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectations {
|
||||||
|
if !strings.Contains(generated, expected) {
|
||||||
|
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -643,7 +692,7 @@ func TestNameConverter_Pluralize(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
sqlType string
|
sqlType string
|
||||||
@@ -658,6 +707,10 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
|||||||
{"timestamp", false, "sql_types.SqlTimeStamp"},
|
{"timestamp", false, "sql_types.SqlTimeStamp"},
|
||||||
{"boolean", true, "bool"},
|
{"boolean", true, "bool"},
|
||||||
{"boolean", false, "sql_types.SqlBool"},
|
{"boolean", false, "sql_types.SqlBool"},
|
||||||
|
{"text[]", true, "sql_types.SqlStringArray"},
|
||||||
|
{"text[]", false, "sql_types.SqlStringArray"},
|
||||||
|
{"integer[]", true, "sql_types.SqlInt32Array"},
|
||||||
|
{"bigint[]", false, "sql_types.SqlInt64Array"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -670,8 +723,23 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTypeMapper_BuildGormTag_ArrayType(t *testing.T) {
|
||||||
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
|
col := &models.Column{
|
||||||
|
Name: "tags",
|
||||||
|
Type: "text[]",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := mapper.BuildGormTag(col, nil)
|
||||||
|
if !strings.Contains(tag, "type:text[]") {
|
||||||
|
t.Fatalf("expected array type to be preserved, got %q", tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) {
|
func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper("")
|
||||||
|
|
||||||
col := &models.Column{
|
col := &models.Column{
|
||||||
Name: "embedding",
|
Name: "embedding",
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ type MigrationWriter struct {
|
|||||||
|
|
||||||
// NewMigrationWriter creates a new templated migration writer
|
// NewMigrationWriter creates a new templated migration writer
|
||||||
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
||||||
|
if options == nil {
|
||||||
|
options = &writers.WriterOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
executor, err := NewTemplateExecutor(options.FlattenSchema)
|
executor, err := NewTemplateExecutor(options.FlattenSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
||||||
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
|
|||||||
|
|
||||||
// WriteMigration generates migration scripts using templates
|
// WriteMigration generates migration scripts using templates
|
||||||
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
|
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
|
||||||
|
if model == nil {
|
||||||
|
return fmt.Errorf("model database is required")
|
||||||
|
}
|
||||||
|
if w.options == nil {
|
||||||
|
w.options = &writers.WriterOptions{}
|
||||||
|
}
|
||||||
|
if current == nil {
|
||||||
|
current = models.InitDatabase(model.Name)
|
||||||
|
}
|
||||||
|
|
||||||
var writer io.Writer
|
var writer io.Writer
|
||||||
var file *os.File
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
|
|||||||
|
|
||||||
// Process each schema in the model
|
// Process each schema in the model
|
||||||
for _, modelSchema := range model.Schemas {
|
for _, modelSchema := range model.Schemas {
|
||||||
|
if modelSchema == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Find corresponding schema in current database
|
// Find corresponding schema in current database
|
||||||
var currentSchema *models.Schema
|
var currentSchema *models.Schema
|
||||||
for _, cs := range current.Schemas {
|
for _, cs := range current.Schemas {
|
||||||
|
if cs == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if strings.EqualFold(cs.Name, modelSchema.Name) {
|
if strings.EqualFold(cs.Name, modelSchema.Name) {
|
||||||
currentSchema = cs
|
currentSchema = cs
|
||||||
break
|
break
|
||||||
@@ -139,6 +160,17 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
|
|||||||
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
||||||
scripts := make([]MigrationScript, 0)
|
scripts := make([]MigrationScript, 0)
|
||||||
|
|
||||||
|
if schemaRequiresPGTrgm(model) {
|
||||||
|
scripts = append(scripts, MigrationScript{
|
||||||
|
ObjectName: "extension.pg_trgm",
|
||||||
|
ObjectType: "create extension",
|
||||||
|
Schema: model.Name,
|
||||||
|
Priority: 80,
|
||||||
|
Sequence: len(scripts),
|
||||||
|
Body: "CREATE EXTENSION IF NOT EXISTS pg_trgm;",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 1: Drop constraints and indexes that changed (Priority 11-50)
|
// Phase 1: Drop constraints and indexes that changed (Priority 11-50)
|
||||||
if current != nil {
|
if current != nil {
|
||||||
dropScripts, err := w.generateDropScripts(model, current)
|
dropScripts, err := w.generateDropScripts(model, current)
|
||||||
@@ -329,14 +361,18 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
// Column doesn't exist, add it
|
// Column doesn't exist, add it
|
||||||
defaultVal := ""
|
defaultVal := ""
|
||||||
if modelCol.Default != nil {
|
if modelCol.Default != nil {
|
||||||
|
if value, ok := modelCol.Default.(string); ok {
|
||||||
|
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
|
||||||
|
} else {
|
||||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sql, err := w.executor.ExecuteAddColumn(AddColumnData{
|
sql, err := w.executor.ExecuteAddColumn(AddColumnData{
|
||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
ColumnType: effectiveColumnSQLType(modelCol),
|
||||||
Default: defaultVal,
|
Default: defaultVal,
|
||||||
NotNull: modelCol.NotNull,
|
NotNull: modelCol.NotNull,
|
||||||
})
|
})
|
||||||
@@ -355,12 +391,13 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
scripts = append(scripts, script)
|
scripts = append(scripts, script)
|
||||||
} else if !columnsEqual(modelCol, currentCol) {
|
} else if !columnsEqual(modelCol, currentCol) {
|
||||||
// Column exists but properties changed
|
// Column exists but properties changed
|
||||||
if modelCol.Type != currentCol.Type {
|
if !columnTypesEqual(modelCol, currentCol) {
|
||||||
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
|
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
|
||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
NewType: effectiveAlterColumnSQLType(modelCol),
|
||||||
|
UsingExpr: buildAlterColumnUsingExpression(modelCol.Name, effectiveAlterColumnSQLType(modelCol)),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -382,8 +419,12 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
setDefault := modelCol.Default != nil
|
setDefault := modelCol.Default != nil
|
||||||
defaultVal := ""
|
defaultVal := ""
|
||||||
if setDefault {
|
if setDefault {
|
||||||
|
if value, ok := modelCol.Default.(string); ok {
|
||||||
|
defaultVal = writers.QuoteDefaultValue(value, modelCol.Type)
|
||||||
|
} else {
|
||||||
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
defaultVal = fmt.Sprintf("%v", modelCol.Default)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
|
sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
|
||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
@@ -537,12 +578,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
|||||||
indexType = modelIndex.Type
|
indexType = modelIndex.Type
|
||||||
}
|
}
|
||||||
|
|
||||||
|
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
|
||||||
|
if len(columnExprs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
|
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
|
||||||
SchemaName: model.Name,
|
SchemaName: model.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
IndexName: indexName,
|
IndexName: indexName,
|
||||||
IndexType: indexType,
|
IndexType: indexType,
|
||||||
Columns: strings.Join(modelIndex.Columns, ", "),
|
Columns: strings.Join(columnExprs, ", "),
|
||||||
Unique: modelIndex.Unique,
|
Unique: modelIndex.Unique,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -565,6 +611,26 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
|||||||
return scripts, nil
|
return scripts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
|
||||||
|
columnExprs := make([]string, 0, len(index.Columns))
|
||||||
|
for _, colName := range index.Columns {
|
||||||
|
colExpr := colName
|
||||||
|
if table != nil {
|
||||||
|
if col, ok := resolveIndexColumn(table, colName); ok && col != nil {
|
||||||
|
colExpr = col.SQLName()
|
||||||
|
if strings.EqualFold(indexType, "gin") {
|
||||||
|
opClass := ginOperatorClassForColumn(col, index.Comment)
|
||||||
|
if opClass != "" {
|
||||||
|
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
columnExprs = append(columnExprs, colExpr)
|
||||||
|
}
|
||||||
|
return columnExprs
|
||||||
|
}
|
||||||
|
|
||||||
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
|
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
|
||||||
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
||||||
scripts := make([]MigrationScript, 0)
|
scripts := make([]MigrationScript, 0)
|
||||||
@@ -820,11 +886,21 @@ func columnsEqual(col1, col2 *models.Column) bool {
|
|||||||
if col1 == nil || col2 == nil {
|
if col1 == nil || col2 == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return strings.EqualFold(col1.Type, col2.Type) &&
|
return columnTypesEqual(col1, col2) &&
|
||||||
col1.NotNull == col2.NotNull &&
|
col1.NotNull == col2.NotNull &&
|
||||||
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
|
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func columnTypesEqual(col1, col2 *models.Column) bool {
|
||||||
|
if col1 == nil || col2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(
|
||||||
|
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col1)),
|
||||||
|
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col2)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// constraintsEqual checks if two constraints are equal
|
// constraintsEqual checks if two constraints are equal
|
||||||
func constraintsEqual(c1, c2 *models.Constraint) bool {
|
func constraintsEqual(c1, c2 *models.Constraint) bool {
|
||||||
if c1 == nil || c2 == nil {
|
if c1 == nil || c2 == nil {
|
||||||
|
|||||||
@@ -57,6 +57,410 @@ func TestWriteMigration_NewTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_ArrayDefault(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("plans", "public")
|
||||||
|
|
||||||
|
tagsCol := models.InitColumn("tags", "plans", "public")
|
||||||
|
tagsCol.Type = "text[]"
|
||||||
|
tagsCol.NotNull = true
|
||||||
|
tagsCol.Default = "''{}''"
|
||||||
|
table.Columns["tags"] = tagsCol
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err = writer.WriteMigration(model, current)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "tags text[] DEFAULT '{}' NOT NULL") {
|
||||||
|
t.Fatalf("expected normalized array default in migration, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "'''{}'''") {
|
||||||
|
t.Fatalf("migration still contains triple-quoted array default:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_AltersColumnTypeWhenActualTypeDiffers(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("learnings", "public")
|
||||||
|
currentDetails := models.InitColumn("details", "learnings", "public")
|
||||||
|
currentDetails.Type = "jsonb"
|
||||||
|
currentTable.Columns["details"] = currentDetails
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("learnings", "public")
|
||||||
|
modelDetails := models.InitColumn("details", "learnings", "public")
|
||||||
|
modelDetails.Type = "text"
|
||||||
|
modelTable.Columns["details"] = modelDetails
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.learnings") || !strings.Contains(output, "ALTER COLUMN details TYPE text") {
|
||||||
|
t.Fatalf("expected migration to alter mismatched column type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN details TYPE text USING details::text;`) {
|
||||||
|
t.Fatalf("expected migration type alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("learnings", "public")
|
||||||
|
currentID := models.InitColumn("id", "learnings", "public")
|
||||||
|
currentID.Type = "uuid"
|
||||||
|
currentTable.Columns["id"] = currentID
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("learnings", "public")
|
||||||
|
modelID := models.InitColumn("id", "learnings", "public")
|
||||||
|
modelID.Type = "bigserial"
|
||||||
|
modelTable.Columns["id"] = modelID
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
|
||||||
|
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
|
||||||
|
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
|
||||||
|
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_ArrayAlterIncludesUsingCast(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("learnings", "public")
|
||||||
|
currentTags := models.InitColumn("tags", "learnings", "public")
|
||||||
|
currentTags.Type = "text"
|
||||||
|
currentTable.Columns["tags"] = currentTags
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("learnings", "public")
|
||||||
|
modelTags := models.InitColumn("tags", "learnings", "public")
|
||||||
|
modelTags.Type = "text[]"
|
||||||
|
modelTable.Columns["tags"] = modelTags
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
|
||||||
|
t.Fatalf("expected array alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_DoesNotAlterEquivalentNormalizedColumnType(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("users", "public")
|
||||||
|
currentEmail := models.InitColumn("email", "users", "public")
|
||||||
|
currentEmail.Type = "character varying"
|
||||||
|
currentEmail.Length = 255
|
||||||
|
currentTable.Columns["email"] = currentEmail
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("users", "public")
|
||||||
|
modelEmail := models.InitColumn("email", "users", "public")
|
||||||
|
modelEmail.Type = "varchar(255)"
|
||||||
|
modelTable.Columns["email"] = modelEmail
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if strings.Contains(output, "ALTER COLUMN email TYPE") {
|
||||||
|
t.Fatalf("did not expect alter type for equivalent normalized types, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("articles", "public")
|
||||||
|
titleCol := models.InitColumn("title", "articles", "public")
|
||||||
|
titleCol.Type = "text"
|
||||||
|
table.Columns["title"] = titleCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_articles_title_gin",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"title"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") {
|
||||||
|
t.Fatalf("expected trigram extension for text GIN migration index, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
|
||||||
|
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("agent_personas", "public")
|
||||||
|
nameCol := models.InitColumn("name", "agent_personas", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_agent_personas_name_gin",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{`"name"`},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (name gin_trgm_ops)") {
|
||||||
|
t.Fatalf("expected quoted text column GIN index to include gin_trgm_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("plans", "public")
|
||||||
|
tagsCol := models.InitColumn("tags", "plans", "public")
|
||||||
|
tagsCol.Type = "text[]"
|
||||||
|
table.Columns["tags"] = tagsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_plans_tags",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"tags"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (tags array_ops)") {
|
||||||
|
t.Fatalf("expected GIN array index with array_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
|
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||||
|
detailsCol.Type = "jsonb"
|
||||||
|
table.Columns["details"] = detailsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_learnings_details",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"details"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
|
||||||
|
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
|
t.Fatalf("did not expect gin_trgm_ops for jsonb migration index, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnJSONBIgnoresIncompatibleTrigramOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||||
|
detailsCol.Type = "jsonb"
|
||||||
|
table.Columns["details"] = detailsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_learnings_details",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"details"},
|
||||||
|
Comment: "gin_trgm_ops",
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
|
||||||
|
t.Fatalf("expected incompatible trigram hint on jsonb to fall back to jsonb_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteMigration_WithAudit(t *testing.T) {
|
func TestWriteMigration_WithAudit(t *testing.T) {
|
||||||
// Current database (empty)
|
// Current database (empty)
|
||||||
current := models.InitDatabase("testdb")
|
current := models.InitDatabase("testdb")
|
||||||
@@ -282,3 +686,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
|||||||
t.Error("Migration missing FOREIGN KEY")
|
t.Error("Migration missing FOREIGN KEY")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewMigrationWriter_NilOptions(t *testing.T) {
|
||||||
|
writer, err := NewMigrationWriter(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
|
||||||
|
}
|
||||||
|
if writer == nil {
|
||||||
|
t.Fatal("expected writer instance")
|
||||||
|
}
|
||||||
|
if writer.options == nil {
|
||||||
|
t.Fatal("expected default writer options to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("users", "public")
|
||||||
|
idCol := models.InitColumn("id", "users", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.NotNull = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, nil); err != nil {
|
||||||
|
t.Fatalf("WriteMigration with nil current failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "CREATE TABLE") {
|
||||||
|
t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed templates/*.tmpl
|
//go:embed templates/*.tmpl
|
||||||
@@ -94,6 +95,16 @@ type AlterColumnTypeData struct {
|
|||||||
TableName string
|
TableName string
|
||||||
ColumnName string
|
ColumnName string
|
||||||
NewType string
|
NewType string
|
||||||
|
UsingExpr string
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlterColumnTypeWithCheckData contains data for the
// alter_column_type_with_check template, which alters a column's type only
// when the live type differs from every listed equivalent spelling.
type AlterColumnTypeWithCheckData struct {
	SchemaName      string // schema that owns the table
	TableName       string // table that owns the column
	ColumnName      string // column whose type may be altered
	NewType         string // target SQL type for the column
	EquivalentTypes string // SQL list of type spellings treated as equal to NewType (spliced into ARRAY[...])
	UsingExpr       string // optional USING cast expression, e.g. "col::text"; empty to omit
}
|
||||||
|
|
||||||
// AlterColumnDefaultData contains data for alter column default template
|
// AlterColumnDefaultData contains data for alter column default template
|
||||||
@@ -266,6 +277,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
|
|||||||
ConstraintName string
|
ConstraintName string
|
||||||
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||||
Columns string
|
Columns string
|
||||||
|
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute methods for each template
|
// Execute methods for each template
|
||||||
@@ -300,6 +312,15 @@ func (te *TemplateExecutor) ExecuteAlterColumnType(data AlterColumnTypeData) (st
|
|||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (te *TemplateExecutor) ExecuteAlterColumnTypeWithCheck(data AlterColumnTypeWithCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "alter_column_type_with_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute alter_column_type_with_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// ExecuteAlterColumnDefault executes the alter column default template
|
// ExecuteAlterColumnDefault executes the alter column default template
|
||||||
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
|
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
@@ -495,8 +516,12 @@ func BuildCreateTableData(schemaName string, table *models.Table) CreateTableDat
|
|||||||
NotNull: col.NotNull,
|
NotNull: col.NotNull,
|
||||||
}
|
}
|
||||||
if col.Default != nil {
|
if col.Default != nil {
|
||||||
|
if value, ok := col.Default.(string); ok {
|
||||||
|
colData.Default = writers.QuoteDefaultValue(value, col.Type)
|
||||||
|
} else {
|
||||||
colData.Default = fmt.Sprintf("%v", col.Default)
|
colData.Default = fmt.Sprintf("%v", col.Default)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columns = append(columns, colData)
|
columns = append(columns, colData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
current_type text;
|
||||||
|
BEGIN
|
||||||
|
SELECT pg_catalog.format_type(a.atttypid, a.atttypmod)
|
||||||
|
INTO current_type
|
||||||
|
FROM pg_attribute a
|
||||||
|
JOIN pg_class t ON t.oid = a.attrelid
|
||||||
|
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||||
|
WHERE n.nspname = '{{.SchemaName}}'
|
||||||
|
AND t.relname = '{{.TableName}}'
|
||||||
|
AND a.attname = '{{.ColumnName}}'
|
||||||
|
AND a.attnum > 0
|
||||||
|
AND NOT a.attisdropped;
|
||||||
|
|
||||||
|
IF current_type IS NOT NULL
|
||||||
|
AND current_type <> ALL(ARRAY[{{.EquivalentTypes}}]) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
|
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,26 +1,42 @@
|
|||||||
DO $$
|
DO $$
|
||||||
DECLARE
|
DECLARE
|
||||||
auto_pk_name text;
|
current_pk_name text;
|
||||||
|
current_pk_matches boolean := false;
|
||||||
BEGIN
|
BEGIN
|
||||||
-- Drop auto-generated primary key if it exists
|
SELECT tc.constraint_name,
|
||||||
SELECT constraint_name INTO auto_pk_name
|
COALESCE(
|
||||||
FROM information_schema.table_constraints
|
ARRAY(
|
||||||
WHERE table_schema = '{{.SchemaName}}'
|
SELECT a.attname::text
|
||||||
AND table_name = '{{.TableName}}'
|
FROM pg_constraint c
|
||||||
AND constraint_type = 'PRIMARY KEY'
|
JOIN pg_class t ON t.oid = c.conrelid
|
||||||
AND constraint_name IN ({{.AutoGenNames}});
|
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||||
|
JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
|
||||||
|
ON TRUE
|
||||||
|
JOIN pg_attribute a
|
||||||
|
ON a.attrelid = t.oid
|
||||||
|
AND a.attnum = cols.attnum
|
||||||
|
WHERE c.contype = 'p'
|
||||||
|
AND n.nspname = '{{.SchemaName}}'
|
||||||
|
AND t.relname = '{{.TableName}}'
|
||||||
|
ORDER BY cols.ord
|
||||||
|
),
|
||||||
|
ARRAY[]::text[]
|
||||||
|
) = ARRAY[{{.ColumnNames}}]
|
||||||
|
INTO current_pk_name, current_pk_matches
|
||||||
|
FROM information_schema.table_constraints tc
|
||||||
|
WHERE tc.table_schema = '{{.SchemaName}}'
|
||||||
|
AND tc.table_name = '{{.TableName}}'
|
||||||
|
AND tc.constraint_type = 'PRIMARY KEY';
|
||||||
|
|
||||||
IF auto_pk_name IS NOT NULL THEN
|
IF current_pk_name IS NOT NULL
|
||||||
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
AND NOT current_pk_matches
|
||||||
|
AND current_pk_name IN ({{.AutoGenNames}}) THEN
|
||||||
|
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
|
||||||
END IF;
|
END IF;
|
||||||
|
|
||||||
-- Add named primary key if it doesn't exist
|
-- Add the desired primary key only when no matching primary key already exists.
|
||||||
IF NOT EXISTS (
|
IF current_pk_name IS NULL
|
||||||
SELECT 1 FROM information_schema.table_constraints
|
OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
|
||||||
WHERE table_schema = '{{.SchemaName}}'
|
|
||||||
AND table_name = '{{.TableName}}'
|
|
||||||
AND constraint_name = '{{.ConstraintName}}'
|
|
||||||
) THEN
|
|
||||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||||
END IF;
|
END IF;
|
||||||
END;
|
END;
|
||||||
|
|||||||
@@ -143,6 +143,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if schemaRequiresPGTrgm(schema) {
|
||||||
|
statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`)
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 2: Create sequences
|
// Phase 2: Create sequences
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
pk := table.GetPrimaryKey()
|
pk := table.GetPrimaryKey()
|
||||||
@@ -181,6 +185,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
}
|
}
|
||||||
statements = append(statements, addColStmts...)
|
statements = append(statements, addColStmts...)
|
||||||
|
|
||||||
|
alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate alter column type statements: %w", err)
|
||||||
|
}
|
||||||
|
statements = append(statements, alterTypeStmts...)
|
||||||
|
|
||||||
// Phase 4: Primary keys
|
// Phase 4: Primary keys
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
// First check for explicit PrimaryKeyConstraint
|
// First check for explicit PrimaryKeyConstraint
|
||||||
@@ -228,6 +238,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
ConstraintName: pkName,
|
ConstraintName: pkName,
|
||||||
AutoGenNames: formatStringList(autoGenPKNames),
|
AutoGenNames: formatStringList(autoGenPKNames),
|
||||||
Columns: strings.Join(pkColumns, ", "),
|
Columns: strings.Join(pkColumns, ", "),
|
||||||
|
ColumnNames: formatStringList(pkColumns),
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||||
@@ -260,16 +271,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
columnExprs := make([]string, 0, len(index.Columns))
|
columnExprs := make([]string, 0, len(index.Columns))
|
||||||
for _, colName := range index.Columns {
|
for _, colName := range index.Columns {
|
||||||
colExpr := colName
|
colExpr := colName
|
||||||
if col, ok := table.Columns[colName]; ok {
|
if col, ok := resolveIndexColumn(table, colName); ok {
|
||||||
// For GIN indexes on text columns, add operator class
|
if strings.EqualFold(indexType, "gin") {
|
||||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" {
|
||||||
opClass := extractOperatorClass(index.Comment)
|
|
||||||
if opClass == "" {
|
|
||||||
opClass = "gin_trgm_ops"
|
|
||||||
}
|
|
||||||
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columnExprs = append(columnExprs, colExpr)
|
columnExprs = append(columnExprs, colExpr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -436,6 +444,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
|
|||||||
return statements, nil
|
return statements, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) {
|
||||||
|
statements := []string{}
|
||||||
|
|
||||||
|
statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name))
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
columns := getSortedColumns(table.Columns)
|
||||||
|
for _, col := range columns {
|
||||||
|
targetType := effectiveAlterColumnSQLType(col)
|
||||||
|
stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ColumnName: col.Name,
|
||||||
|
NewType: targetType,
|
||||||
|
EquivalentTypes: equivalentTypeListSQL(targetType),
|
||||||
|
UsingExpr: buildAlterColumnUsingExpression(col.Name, targetType),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
||||||
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
||||||
statements := []string{}
|
statements := []string{}
|
||||||
@@ -488,31 +523,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
|||||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||||
parts := []string{col.SQLName()}
|
parts := []string{col.SQLName()}
|
||||||
|
|
||||||
// Type with length/precision - convert to valid PostgreSQL type
|
parts = append(parts, effectiveColumnSQLType(col))
|
||||||
baseType := pgsql.ConvertSQLType(col.Type)
|
|
||||||
typeStr := baseType
|
|
||||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
|
||||||
|
|
||||||
// Only add size specifiers for types that support them
|
|
||||||
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
|
||||||
if pgsql.SupportsLength(baseType) {
|
|
||||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
|
||||||
} else if isTextTypeWithoutLength(baseType) {
|
|
||||||
// Convert text with length to varchar
|
|
||||||
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
|
||||||
}
|
|
||||||
// For types that don't support length (integer, bigint, etc.), ignore the length
|
|
||||||
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
|
||||||
if pgsql.SupportsPrecision(baseType) {
|
|
||||||
if col.Scale > 0 {
|
|
||||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
|
||||||
} else {
|
|
||||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// For types that don't support precision, ignore it
|
|
||||||
}
|
|
||||||
parts = append(parts, typeStr)
|
|
||||||
|
|
||||||
// NOT NULL
|
// NOT NULL
|
||||||
if col.NotNull {
|
if col.NotNull {
|
||||||
@@ -523,15 +534,7 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|||||||
if col.Default != nil {
|
if col.Default != nil {
|
||||||
switch v := col.Default.(type) {
|
switch v := col.Default.(type) {
|
||||||
case string:
|
case string:
|
||||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", writers.QuoteDefaultValue(stripBackticks(v), col.Type)))
|
||||||
cleanDefault := stripBackticks(v)
|
|
||||||
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
|
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
|
||||||
} else if cleanDefault == "true" || cleanDefault == "false" {
|
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
|
||||||
} else {
|
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
|
|
||||||
}
|
|
||||||
case bool:
|
case bool:
|
||||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||||
default:
|
default:
|
||||||
@@ -542,6 +545,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|||||||
return strings.Join(parts, " ")
|
return strings.Join(parts, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func effectiveColumnSQLType(col *models.Column) string {
|
||||||
|
if col == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
baseType := pgsql.ConvertSQLType(col.Type)
|
||||||
|
typeStr := baseType
|
||||||
|
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
||||||
|
|
||||||
|
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
||||||
|
if pgsql.SupportsLength(baseType) {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||||
|
} else if isTextTypeWithoutLength(baseType) {
|
||||||
|
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||||
|
}
|
||||||
|
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
||||||
|
if pgsql.SupportsPrecision(baseType) {
|
||||||
|
if col.Scale > 0 {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||||
|
} else {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typeStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveAlterColumnSQLType(col *models.Column) string {
|
||||||
|
typeStr := effectiveColumnSQLType(col)
|
||||||
|
switch strings.ToLower(strings.TrimSpace(typeStr)) {
|
||||||
|
case "smallserial":
|
||||||
|
return "smallint"
|
||||||
|
case "serial":
|
||||||
|
return "integer"
|
||||||
|
case "bigserial":
|
||||||
|
return "bigint"
|
||||||
|
default:
|
||||||
|
return typeStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAlterColumnUsingExpression(columnName, targetType string) string {
|
||||||
|
if strings.TrimSpace(columnName) == "" || strings.TrimSpace(targetType) == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func equivalentTypeListSQL(sqlType string) string {
|
||||||
|
variants := pgsql.EquivalentSQLTypeVariants(sqlType)
|
||||||
|
quoted := make([]string, 0, len(variants))
|
||||||
|
for _, variant := range variants {
|
||||||
|
quoted = append(quoted, fmt.Sprintf("'%s'", escapeQuote(variant)))
|
||||||
|
}
|
||||||
|
return strings.Join(quoted, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
// WriteSchema writes a single schema and all its tables
|
// WriteSchema writes a single schema and all its tables
|
||||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||||
if w.writer == nil {
|
if w.writer == nil {
|
||||||
@@ -553,6 +614,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := w.writeRequiredExtensions(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 2: Create sequences (priority 80)
|
// Phase 2: Create sequences (priority 80)
|
||||||
if err := w.writeSequences(schema); err != nil {
|
if err := w.writeSequences(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -568,6 +633,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := w.writeAlterColumnTypes(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 4: Create primary keys (priority 160)
|
// Phase 4: Create primary keys (priority 160)
|
||||||
if err := w.writePrimaryKeys(schema); err != nil {
|
if err := w.writePrimaryKeys(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -668,6 +737,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) writeRequiredExtensions(schema *models.Schema) error {
|
||||||
|
if !schemaRequiresPGTrgm(schema) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;")
|
||||||
|
fmt.Fprintln(w.writer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
||||||
func (w *Writer) writeSequences(schema *models.Schema) error {
|
func (w *Writer) writeSequences(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
||||||
@@ -761,6 +840,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error {
|
||||||
|
fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name)
|
||||||
|
|
||||||
|
statements, err := w.GenerateAlterColumnTypeStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, stmt := range statements[1:] {
|
||||||
|
fmt.Fprint(w.writer, stmt)
|
||||||
|
fmt.Fprint(w.writer, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
||||||
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
||||||
@@ -814,6 +908,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
|||||||
ConstraintName: pkName,
|
ConstraintName: pkName,
|
||||||
AutoGenNames: formatStringList(autoGenPKNames),
|
AutoGenNames: formatStringList(autoGenPKNames),
|
||||||
Columns: strings.Join(columnNames, ", "),
|
Columns: strings.Join(columnNames, ", "),
|
||||||
|
ColumnNames: formatStringList(columnNames),
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||||
@@ -861,16 +956,14 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
|||||||
// Build column list with operator class support for GIN indexes
|
// Build column list with operator class support for GIN indexes
|
||||||
columnExprs := make([]string, 0, len(index.Columns))
|
columnExprs := make([]string, 0, len(index.Columns))
|
||||||
for _, colName := range index.Columns {
|
for _, colName := range index.Columns {
|
||||||
if col, ok := table.Columns[colName]; ok {
|
if col, ok := resolveIndexColumn(table, colName); ok {
|
||||||
colExpr := col.SQLName()
|
colExpr := col.SQLName()
|
||||||
// For GIN indexes on text columns, add operator class
|
if strings.EqualFold(index.Type, "gin") {
|
||||||
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
|
opClass := ginOperatorClassForColumn(col, index.Comment)
|
||||||
opClass := extractOperatorClass(index.Comment)
|
if opClass != "" {
|
||||||
if opClass == "" {
|
|
||||||
opClass = "gin_trgm_ops"
|
|
||||||
}
|
|
||||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columnExprs = append(columnExprs, colExpr)
|
columnExprs = append(columnExprs, colExpr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1256,20 +1349,126 @@ func isIntegerType(colType string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isTextType checks if a column type is a text type (for GIN index operator class)
|
// isTextType checks if a column type is a text type (for GIN index operator class)
|
||||||
func isTextType(colType string) bool {
|
// func isTextType(colType string) bool {
|
||||||
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
// textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
||||||
lowerType := strings.ToLower(colType)
|
// lowerType := strings.ToLower(colType)
|
||||||
for _, t := range textTypes {
|
// if strings.HasSuffix(lowerType, "[]") {
|
||||||
if strings.HasPrefix(lowerType, t) {
|
// return false
|
||||||
|
// }
|
||||||
|
// for _, t := range textTypes {
|
||||||
|
// if strings.HasPrefix(lowerType, t) {
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
||||||
|
func isTextTypeWithoutLength(colType string) bool {
|
||||||
|
return strings.EqualFold(colType, "text")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ginOperatorClassForColumn(col *models.Column, comment string) string {
|
||||||
|
if col == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlType := effectiveColumnSQLType(col)
|
||||||
|
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||||
|
isArray := pgsql.IsArrayType(sqlType)
|
||||||
|
requested := extractOperatorClass(comment)
|
||||||
|
|
||||||
|
if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) {
|
||||||
|
return requested
|
||||||
|
}
|
||||||
|
|
||||||
|
if isArray {
|
||||||
|
return "array_ops"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case isTextGinBaseType(baseType):
|
||||||
|
return "gin_trgm_ops"
|
||||||
|
case baseType == "jsonb":
|
||||||
|
return "jsonb_ops"
|
||||||
|
default:
|
||||||
|
return requested
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool {
|
||||||
|
switch opClass {
|
||||||
|
case "gin_trgm_ops", "gin_bigm_ops":
|
||||||
|
return !isArray && isTextGinBaseType(baseType)
|
||||||
|
case "jsonb_ops", "jsonb_path_ops":
|
||||||
|
return !isArray && baseType == "jsonb"
|
||||||
|
case "array_ops":
|
||||||
|
return isArray
|
||||||
|
default:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTextGinBaseType(baseType string) bool {
|
||||||
|
switch baseType {
|
||||||
|
case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaRequiresPGTrgm(schema *models.Schema) bool {
|
||||||
|
if schema == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
if table == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, index := range table.Indexes {
|
||||||
|
if index == nil || !strings.EqualFold(index.Type, "gin") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, colName := range index.Columns {
|
||||||
|
col, ok := resolveIndexColumn(table, colName)
|
||||||
|
if !ok || col == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
|
||||||
func isTextTypeWithoutLength(colType string) bool {
|
if table == nil {
|
||||||
return strings.EqualFold(colType, "text")
|
return nil, false
|
||||||
|
}
|
||||||
|
if col, ok := table.Columns[colName]; ok && col != nil {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := strings.ToLower(strings.Trim(colName, `"`))
|
||||||
|
for key, col := range table.Columns {
|
||||||
|
if col == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.Trim(key, `"`)) == normalized {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.Trim(col.Name, `"`)) == normalized {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.Trim(col.SQLName(), `"`)) == normalized {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
|
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
|
||||||
|
|||||||
@@ -87,6 +87,117 @@ func TestWriteDatabase(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("plans", "public")
|
||||||
|
|
||||||
|
tagsCol := models.InitColumn("tags", "plans", "public")
|
||||||
|
tagsCol.Type = "text[]"
|
||||||
|
table.Columns["tags"] = tagsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_plans_tags",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"tags"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `USING gin (tags array_ops)`) {
|
||||||
|
t.Fatalf("expected GIN index on array column with array_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
|
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("agent_personas", "public")
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "agent_personas", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_agent_personas_name_gin",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{`"name"`},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) {
|
||||||
|
t.Fatalf("expected trigram extension for text GIN index, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) {
|
||||||
|
t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
|
||||||
|
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||||
|
detailsCol.Type = "jsonb"
|
||||||
|
table.Columns["details"] = detailsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_learnings_details",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"details"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `USING gin (details jsonb_ops)`) {
|
||||||
|
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
|
t.Fatalf("did not expect gin_trgm_ops for jsonb column, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteForeignKeys(t *testing.T) {
|
func TestWriteForeignKeys(t *testing.T) {
|
||||||
// Create a test database with two related tables
|
// Create a test database with two related tables
|
||||||
db := models.InitDatabase("testdb")
|
db := models.InitDatabase("testdb")
|
||||||
@@ -636,9 +747,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
|||||||
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it checks for our specific named constraint before adding it
|
// Verify it compares the current primary key columns before dropping/recreating
|
||||||
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
|
||||||
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
|
||||||
|
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
|
||||||
|
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -729,6 +845,43 @@ func TestColumnSizeSpecifiers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "learnings", "public")
|
||||||
|
idCol.Type = "bigint"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
|
||||||
|
parentCol.Type = "bigint"
|
||||||
|
table.Columns["duplicate_of_learning_id"] = parentCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "current_pk_matches") {
|
||||||
|
t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ARRAY['id']") {
|
||||||
|
t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "NOT current_pk_matches") {
|
||||||
|
t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
|
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
|
||||||
writer := NewWriter(&writers.WriterOptions{})
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
|
||||||
@@ -905,3 +1058,82 @@ func TestWriteAddColumnStatements(t *testing.T) {
|
|||||||
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteSchema_EmitsGuardedAlterColumnTypeStatements(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("agent_skills", "public")
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "agent_skills", "public")
|
||||||
|
nameCol.Type = "character varying"
|
||||||
|
nameCol.Length = 255
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
tagsCol := models.InitColumn("tags", "agent_skills", "public")
|
||||||
|
tagsCol.Type = "text[]"
|
||||||
|
table.Columns["tags"] = tagsCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "-- Alter column types for schema: public") {
|
||||||
|
t.Fatalf("expected alter column type section, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "pg_catalog.format_type") {
|
||||||
|
t.Fatalf("expected guarded live-type check, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN name TYPE character varying(255)") {
|
||||||
|
t.Fatalf("expected guarded alter for character varying(255), got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ARRAY['varchar(255)', 'character varying(255)']") {
|
||||||
|
t.Fatalf("expected equivalent type spellings for varchar guard, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN tags TYPE text[]") {
|
||||||
|
t.Fatalf("expected guarded alter for array type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
|
||||||
|
t.Fatalf("expected guarded alter for array type to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteSchema_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
idCol := models.InitColumn("id", "learnings", "public")
|
||||||
|
idCol.Type = "bigserial"
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
|
||||||
|
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
|
||||||
|
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
|
||||||
|
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (w *Writer) databaseToPrisma(db *models.Database) string {
|
|||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
|
|
||||||
// Write generator block
|
// Write generator block
|
||||||
sb.WriteString(w.generateGenerator())
|
sb.WriteString(w.generateGenerator(db))
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
|
|
||||||
// Process all schemas (typically just one in Prisma)
|
// Process all schemas (typically just one in Prisma)
|
||||||
@@ -114,13 +114,28 @@ func (w *Writer) generateDatasource(db *models.Database) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateGenerator generates the generator block
|
// generateGenerator generates the generator block
|
||||||
func (w *Writer) generateGenerator() string {
|
func (w *Writer) generateGenerator(db *models.Database) string {
|
||||||
|
if w.usePrisma7Generator(db) {
|
||||||
|
return `generator client {
|
||||||
|
provider = "prisma-client"
|
||||||
|
output = "./generated"
|
||||||
|
}
|
||||||
|
`
|
||||||
|
}
|
||||||
|
|
||||||
return `generator client {
|
return `generator client {
|
||||||
provider = "prisma-client-js"
|
provider = "prisma-client-js"
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) usePrisma7Generator(db *models.Database) bool {
|
||||||
|
if w.options != nil && w.options.Prisma7 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return db != nil && db.SourceFormat == "prisma7"
|
||||||
|
}
|
||||||
|
|
||||||
// enumToPrisma converts an Enum to Prisma enum block
|
// enumToPrisma converts an Enum to Prisma enum block
|
||||||
func (w *Writer) enumToPrisma(enum *models.Enum) string {
|
func (w *Writer) enumToPrisma(enum *models.Enum) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|||||||
52
pkg/writers/prisma/writer_test.go
Normal file
52
pkg/writers/prisma/writer_test.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package prisma
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateGenerator_DefaultsToPrismaClientJS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
|
||||||
|
got := writer.generateGenerator(db)
|
||||||
|
if !strings.Contains(got, `provider = "prisma-client-js"`) {
|
||||||
|
t.Fatalf("expected prisma-client-js generator, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if strings.Contains(got, `output = "./generated"`) {
|
||||||
|
t.Fatalf("did not expect prisma7 output path in default generator:\n%s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateGenerator_Prisma7FlagUsesPrismaClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
writer := NewWriter(&writers.WriterOptions{Prisma7: true})
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
|
||||||
|
got := writer.generateGenerator(db)
|
||||||
|
if !strings.Contains(got, `provider = "prisma-client"`) {
|
||||||
|
t.Fatalf("expected prisma-client generator, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, `output = "./generated"`) {
|
||||||
|
t.Fatalf("expected prisma7 output path, got:\n%s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateGenerator_Prisma7SourceFormatUsesPrismaClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
db.SourceFormat = "prisma7"
|
||||||
|
|
||||||
|
got := writer.generateGenerator(db)
|
||||||
|
if !strings.Contains(got, `provider = "prisma-client"`) {
|
||||||
|
t.Fatalf("expected prisma-client generator from source format, got:\n%s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -20,6 +20,18 @@ type Writer interface {
|
|||||||
WriteTable(table *models.Table) error
|
WriteTable(table *models.Table) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NullableType constants control which Go package is used for nullable column types
|
||||||
|
// in code-generation writers (Bun, GORM).
|
||||||
|
const (
|
||||||
|
// NullableTypeResolveSpec uses github.com/bitechdev/ResolveSpec/pkg/spectypes
|
||||||
|
// (SqlString, SqlInt32, SqlVector, SqlStringArray, …). This is the default.
|
||||||
|
NullableTypeResolveSpec = "resolvespec"
|
||||||
|
|
||||||
|
// NullableTypeStdlib uses the standard library database/sql nullable types
|
||||||
|
// (sql.NullString, sql.NullInt32, …) and plain Go slices for arrays.
|
||||||
|
NullableTypeStdlib = "stdlib"
|
||||||
|
)
|
||||||
|
|
||||||
// WriterOptions contains common options for writers
|
// WriterOptions contains common options for writers
|
||||||
type WriterOptions struct {
|
type WriterOptions struct {
|
||||||
// OutputPath is the path where the output should be written
|
// OutputPath is the path where the output should be written
|
||||||
@@ -33,6 +45,15 @@ type WriterOptions struct {
|
|||||||
// Useful for databases like SQLite that do not support schemas.
|
// Useful for databases like SQLite that do not support schemas.
|
||||||
FlattenSchema bool
|
FlattenSchema bool
|
||||||
|
|
||||||
|
// NullableTypes selects the Go type package used for nullable columns in
|
||||||
|
// code-generation writers (bun, gorm). Accepted values:
|
||||||
|
// "resolvespec" (default) — github.com/bitechdev/ResolveSpec/pkg/spectypes
|
||||||
|
// "stdlib" — database/sql (sql.NullString, sql.NullInt32, …)
|
||||||
|
NullableTypes string
|
||||||
|
|
||||||
|
// Prisma7 enables Prisma 7-specific output for Prisma writers.
|
||||||
|
Prisma7 bool
|
||||||
|
|
||||||
// Additional options can be added here as needed
|
// Additional options can be added here as needed
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
}
|
}
|
||||||
@@ -92,8 +113,12 @@ func SanitizeFilename(name string) string {
|
|||||||
// Examples (bigint): "0" → "0"
|
// Examples (bigint): "0" → "0"
|
||||||
// Examples (timestamp): "now()" → "now()" (function call – never quoted)
|
// Examples (timestamp): "now()" → "now()" (function call – never quoted)
|
||||||
func QuoteDefaultValue(value, sqlType string) string {
|
func QuoteDefaultValue(value, sqlType string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
// Function calls are never quoted regardless of column type.
|
// Function calls are never quoted regardless of column type.
|
||||||
if strings.Contains(value, "(") || strings.Contains(value, ")") {
|
if strings.Contains(value, "(") || strings.Contains(value, ")") ||
|
||||||
|
strings.Contains(value, "::") ||
|
||||||
|
strings.HasPrefix(strings.ToUpper(value), "ARRAY[") {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +128,16 @@ func QuoteDefaultValue(value, sqlType string) string {
|
|||||||
baseType = baseType[:idx]
|
baseType = baseType[:idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isArraySQLType(baseType) {
|
||||||
|
if arrayLiteral, ok := normalizeArrayDefaultLiteral(value); ok {
|
||||||
|
return quoteSQLLiteral(arrayLiteral)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isQuotedSQLLiteral(value) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// Types whose default values must NOT be quoted.
|
// Types whose default values must NOT be quoted.
|
||||||
unquotedTypes := map[string]bool{
|
unquotedTypes := map[string]bool{
|
||||||
// Integer types
|
// Integer types
|
||||||
@@ -136,7 +171,32 @@ func QuoteDefaultValue(value, sqlType string) string {
|
|||||||
|
|
||||||
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
|
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
|
||||||
// is treated as a quoted literal.
|
// is treated as a quoted literal.
|
||||||
return "'" + value + "'"
|
return quoteSQLLiteral(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isArraySQLType reports whether sqlType denotes a PostgreSQL-style
// array type, i.e. it carries the trailing "[]" suffix (e.g. "text[]").
func isArraySQLType(sqlType string) bool {
	const arraySuffix = "[]"
	return strings.HasSuffix(sqlType, arraySuffix)
}
|
||||||
|
|
||||||
|
// normalizeArrayDefaultLiteral strips redundant surrounding quoting from an
// array default value, returning the bare "{...}" literal and true on success.
// It accepts three input shapes, checked from most- to least-quoted:
//   - doubly quoted:  ''{...}''  →  {...}
//   - singly quoted:  '{...}'    →  {...}
//   - bare:           {...}      →  {...}
// Anything else is not an array literal and yields ("", false).
func normalizeArrayDefaultLiteral(value string) (string, bool) {
	if strings.HasPrefix(value, "''{") && strings.HasSuffix(value, "}''") {
		return value[2 : len(value)-2], true
	}
	if strings.HasPrefix(value, "'{") && strings.HasSuffix(value, "}'") {
		return value[1 : len(value)-1], true
	}
	if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
		return value, true
	}
	return "", false
}
|
||||||
|
|
||||||
|
// isQuotedSQLLiteral reports whether value is already wrapped in single
// quotes: it must be at least two characters long and both begin and end
// with a ' character (so a lone "'" does not count).
func isQuotedSQLLiteral(value string) bool {
	if len(value) < 2 {
		return false
	}
	return strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")
}
|
||||||
|
|
||||||
|
// quoteSQLLiteral wraps value in single quotes, doubling any embedded
// single quotes per the standard SQL string-escaping convention
// (e.g. o'brien → 'o''brien').
func quoteSQLLiteral(value string) string {
	var b strings.Builder
	b.Grow(len(value) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(value); i++ {
		c := value[i]
		b.WriteByte(c)
		// Escape an embedded quote by emitting it twice.
		if c == '\'' {
			b.WriteByte('\'')
		}
	}
	b.WriteByte('\'')
	return b.String()
}
|
||||||
|
|
||||||
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
||||||
@@ -147,7 +207,8 @@ func QuoteDefaultValue(value, sqlType string) string {
|
|||||||
// - Returns a clean identifier safe for use in struct tags and field names
|
// - Returns a clean identifier safe for use in struct tags and field names
|
||||||
func SanitizeStructTagValue(value string) string {
|
func SanitizeStructTagValue(value string) string {
|
||||||
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
|
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
|
||||||
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
|
// Require at least one character inside brackets to avoid stripping PostgreSQL array suffix "[]"
|
||||||
|
commentRegex := regexp.MustCompile(`\s*\[[^\]]+\]\s*`)
|
||||||
value = commentRegex.ReplaceAllString(value, "")
|
value = commentRegex.ReplaceAllString(value, "")
|
||||||
|
|
||||||
// Trim whitespace
|
// Trim whitespace
|
||||||
|
|||||||
54
pkg/writers/writer_test.go
Normal file
54
pkg/writers/writer_test.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package writers
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestQuoteDefaultValue(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
sqlType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "text default is quoted",
|
||||||
|
value: "active",
|
||||||
|
sqlType: "text",
|
||||||
|
want: "'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array default from bare literal is quoted once",
|
||||||
|
value: "{}",
|
||||||
|
sqlType: "text[]",
|
||||||
|
want: "'{}'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array default from quoted literal is preserved",
|
||||||
|
value: "'{}'",
|
||||||
|
sqlType: "text[]",
|
||||||
|
want: "'{}'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array default from double quoted literal is normalized",
|
||||||
|
value: "''{}''",
|
||||||
|
sqlType: "text[]",
|
||||||
|
want: "'{}'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function default is left alone",
|
||||||
|
value: "now()",
|
||||||
|
sqlType: "timestamptz",
|
||||||
|
want: "now()",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := QuoteDefaultValue(tt.value, tt.sqlType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("QuoteDefaultValue(%q, %q) = %q, want %q", tt.value, tt.sqlType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user