Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bb7ceb37fe | |||
| 6a759ef3d1 | |||
| cb735f0754 | |||
| 80fb49bc5e | |||
| 9190df81dd | |||
| 9235ef5e08 | |||
| b91d6b33b5 | |||
| 30ef1db010 | |||
| 2d97a47ee1 | |||
| 72200ea72e | |||
| 608893a3d6 |
@@ -53,6 +53,7 @@ var (
|
|||||||
convertSchemaFilter string
|
convertSchemaFilter string
|
||||||
convertFlattenSchema bool
|
convertFlattenSchema bool
|
||||||
convertNullableTypes string
|
convertNullableTypes string
|
||||||
|
convertContinueOnError bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var convertCmd = &cobra.Command{
|
var convertCmd = &cobra.Command{
|
||||||
@@ -177,6 +178,7 @@ func init() {
|
|||||||
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
||||||
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||||
convertCmd.Flags().StringVar(&convertNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
|
convertCmd.Flags().StringVar(&convertNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
|
||||||
|
convertCmd.Flags().BoolVar(&convertContinueOnError, "continue-on-error", false, "Prepend \\set ON_ERROR_STOP off to generated SQL so psql continues past errors (pgsql output only)")
|
||||||
|
|
||||||
err := convertCmd.MarkFlagRequired("from")
|
err := convertCmd.MarkFlagRequired("from")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -243,7 +245,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes); err != nil {
|
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes, convertContinueOnError); err != nil {
|
||||||
return fmt.Errorf("failed to write target: %w", err)
|
return fmt.Errorf("failed to write target: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,10 +385,10 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error {
|
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string, continueOnError bool) error {
|
||||||
var writer writers.Writer
|
var writer writers.Writer
|
||||||
|
|
||||||
writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes)
|
writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes, continueOnError)
|
||||||
|
|
||||||
switch strings.ToLower(dbType) {
|
switch strings.ToLower(dbType) {
|
||||||
case "dbml":
|
case "dbml":
|
||||||
|
|||||||
+13
-13
@@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa
|
|||||||
|
|
||||||
switch strings.ToLower(dbType) {
|
switch strings.ToLower(dbType) {
|
||||||
case "dbml":
|
case "dbml":
|
||||||
writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "dctx":
|
case "dctx":
|
||||||
writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "graphql":
|
case "graphql":
|
||||||
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "json":
|
case "json":
|
||||||
writer = wjson.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wjson.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "yaml":
|
case "yaml":
|
||||||
writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "gorm":
|
case "gorm":
|
||||||
writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "bun":
|
case "bun":
|
||||||
writer = wbun.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wbun.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "prisma":
|
case "prisma":
|
||||||
writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, ""))
|
writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, "", false))
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
return fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
||||||
}
|
}
|
||||||
|
|||||||
+18
-13
@@ -258,6 +258,11 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
|
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
|
||||||
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
|
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
|
||||||
|
|
||||||
|
if strings.EqualFold(mergeOutputType, "pgsql") && len(result.TypeConflicts) > 0 {
|
||||||
|
return fmt.Errorf("merge detected conflicting existing column types and cannot safely continue with pgsql output\n%s",
|
||||||
|
merge.GetColumnTypeConflictSummary(result, 10))
|
||||||
|
}
|
||||||
|
|
||||||
// Step 4: Write output
|
// Step 4: Write output
|
||||||
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
|
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
|
||||||
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
|
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
|
||||||
@@ -370,61 +375,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DBML format", label)
|
return fmt.Errorf("%s: file path is required for DBML format", label)
|
||||||
}
|
}
|
||||||
writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||||
}
|
}
|
||||||
writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||||
}
|
}
|
||||||
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "graphql":
|
case "graphql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||||
}
|
}
|
||||||
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for JSON format", label)
|
return fmt.Errorf("%s: file path is required for JSON format", label)
|
||||||
}
|
}
|
||||||
writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "yaml":
|
case "yaml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for YAML format", label)
|
return fmt.Errorf("%s: file path is required for YAML format", label)
|
||||||
}
|
}
|
||||||
writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for GORM format", label)
|
return fmt.Errorf("%s: file path is required for GORM format", label)
|
||||||
}
|
}
|
||||||
writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Bun format", label)
|
return fmt.Errorf("%s: file path is required for Bun format", label)
|
||||||
}
|
}
|
||||||
writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||||
}
|
}
|
||||||
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||||
}
|
}
|
||||||
writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||||
}
|
}
|
||||||
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "sqlite", "sqlite3":
|
case "sqlite", "sqlite3":
|
||||||
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
|
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
writerOpts := newWriterOptions(filePath, "", flattenSchema, "")
|
writerOpts := newWriterOptions(filePath, "", flattenSchema, "", false)
|
||||||
if connString != "" {
|
if connString != "" {
|
||||||
writerOpts.Metadata = map[string]interface{}{
|
writerOpts.Metadata = map[string]interface{}{
|
||||||
"connection_string": connString,
|
"connection_string": connString,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,3 +161,38 @@ func TestRunMerge_FromListMissingSourceType(t *testing.T) {
|
|||||||
t.Error("expected error when neither --source-path nor --from-list is provided")
|
t.Error("expected error when neither --source-path nor --from-list is provided")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunMerge_PgsqlOutputRejectsColumnTypeConflict(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
targetFile := filepath.Join(dir, "target.json")
|
||||||
|
sourceFile := filepath.Join(dir, "source.json")
|
||||||
|
writeTestJSONWithSingleColumnType(t, targetFile, "users", "integer")
|
||||||
|
writeTestJSONWithSingleColumnType(t, sourceFile, "users", "uuid")
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = targetFile
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = sourceFile
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = nil
|
||||||
|
mergeOutputType = "pgsql"
|
||||||
|
mergeOutputPath = ""
|
||||||
|
mergeOutputConn = "postgres://relspec:secret@localhost/testdb"
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
err := runMerge(nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected pgsql output merge to fail on column type conflict")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "column type conflicts detected") {
|
||||||
|
t.Fatalf("expected conflict summary in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "public.users.id") {
|
||||||
|
t.Fatalf("expected conflicting column path in error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ func newReaderOptions(filePath, connString string) *readers.ReaderOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions {
|
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string, continueOnError bool) *writers.WriterOptions {
|
||||||
return &writers.WriterOptions{
|
return &writers.WriterOptions{
|
||||||
OutputPath: outputPath,
|
OutputPath: outputPath,
|
||||||
PackageName: packageName,
|
PackageName: packageName,
|
||||||
FlattenSchema: flattenSchema,
|
FlattenSchema: flattenSchema,
|
||||||
NullableTypes: nullableTypes,
|
NullableTypes: nullableTypes,
|
||||||
Prisma7: prisma7,
|
Prisma7: prisma7,
|
||||||
|
ContinueOnError: continueOnError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ func runSplit(cmd *cobra.Command, args []string) error {
|
|||||||
"", // no schema filter for split
|
"", // no schema filter for split
|
||||||
false, // no flatten-schema for split
|
false, // no flatten-schema for split
|
||||||
splitNullableTypes,
|
splitNullableTypes,
|
||||||
|
false, // no continue-on-error for split
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write output: %w", err)
|
return fmt.Errorf("failed to write output: %w", err)
|
||||||
|
|||||||
@@ -71,6 +71,40 @@ func writeTestJSON(t *testing.T, path string, tableNames []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeTestJSONWithSingleColumnType(t *testing.T, path, tableName, columnType string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := minimalDatabase{
|
||||||
|
Name: "test_db",
|
||||||
|
Schemas: []minimalSchema{{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []minimalTable{{
|
||||||
|
Name: tableName,
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]minimalColumn{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Table: tableName,
|
||||||
|
Schema: "public",
|
||||||
|
Type: columnType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal test JSON: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convertState captures and restores all convert global vars.
|
// convertState captures and restores all convert global vars.
|
||||||
type convertState struct {
|
type convertState struct {
|
||||||
sourceType string
|
sourceType string
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
|
||||||
pkgname=relspec
|
pkgname=relspec
|
||||||
pkgver=1.0.54
|
pkgver=1.0.58
|
||||||
pkgrel=1
|
pkgrel=1
|
||||||
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
|
||||||
arch=('x86_64' 'aarch64')
|
arch=('x86_64' 'aarch64')
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
Name: relspec
|
Name: relspec
|
||||||
Version: 1.0.54
|
Version: 1.0.58
|
||||||
Release: 1%{?dist}
|
Release: 1%{?dist}
|
||||||
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,156 @@
|
|||||||
|
package mariadb
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// MariaDBToCanonicalTypes maps MariaDB/MySQL type names to canonical types.
|
||||||
|
var MariaDBToCanonicalTypes = map[string]string{
|
||||||
|
// Integer types
|
||||||
|
"tinyint": "int8",
|
||||||
|
"smallint": "int16",
|
||||||
|
"mediumint": "int",
|
||||||
|
"int": "int",
|
||||||
|
"integer": "int",
|
||||||
|
"int2": "int16",
|
||||||
|
"int4": "int",
|
||||||
|
"int8": "int64",
|
||||||
|
"bigint": "int64",
|
||||||
|
// Boolean (TINYINT(1) alias)
|
||||||
|
"boolean": "bool",
|
||||||
|
"bool": "bool",
|
||||||
|
"bit": "bool",
|
||||||
|
// Float types
|
||||||
|
"float": "float32",
|
||||||
|
"double": "float64",
|
||||||
|
"real": "float64",
|
||||||
|
"double precision": "float64",
|
||||||
|
// Decimal types
|
||||||
|
"decimal": "decimal",
|
||||||
|
"numeric": "decimal",
|
||||||
|
"dec": "decimal",
|
||||||
|
"fixed": "decimal",
|
||||||
|
// String types
|
||||||
|
"char": "string",
|
||||||
|
"character": "string",
|
||||||
|
"varchar": "string",
|
||||||
|
"nchar": "string",
|
||||||
|
"nvarchar": "string",
|
||||||
|
"tinytext": "text",
|
||||||
|
"text": "text",
|
||||||
|
"mediumtext": "text",
|
||||||
|
"longtext": "text",
|
||||||
|
// Binary/blob types
|
||||||
|
"binary": "bytea",
|
||||||
|
"varbinary": "bytea",
|
||||||
|
"tinyblob": "bytea",
|
||||||
|
"blob": "bytea",
|
||||||
|
"mediumblob": "bytea",
|
||||||
|
"longblob": "bytea",
|
||||||
|
// Date/time types
|
||||||
|
"date": "date",
|
||||||
|
"time": "time",
|
||||||
|
"datetime": "timestamp",
|
||||||
|
"timestamp": "timestamp",
|
||||||
|
"year": "int",
|
||||||
|
// Other types
|
||||||
|
"json": "json",
|
||||||
|
"enum": "string",
|
||||||
|
"set": "string",
|
||||||
|
"uuid": "uuid",
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalToMariaDBTypes maps canonical types to MariaDB/MySQL types.
|
||||||
|
var CanonicalToMariaDBTypes = map[string]string{
|
||||||
|
"bool": "TINYINT(1)",
|
||||||
|
"int8": "TINYINT",
|
||||||
|
"int16": "SMALLINT",
|
||||||
|
"int": "INT",
|
||||||
|
"int32": "INT",
|
||||||
|
"int64": "BIGINT",
|
||||||
|
"uint": "INT UNSIGNED",
|
||||||
|
"uint8": "TINYINT UNSIGNED",
|
||||||
|
"uint16": "SMALLINT UNSIGNED",
|
||||||
|
"uint32": "INT UNSIGNED",
|
||||||
|
"uint64": "BIGINT UNSIGNED",
|
||||||
|
"float32": "FLOAT",
|
||||||
|
"float64": "DOUBLE",
|
||||||
|
"decimal": "DECIMAL",
|
||||||
|
"string": "VARCHAR(255)",
|
||||||
|
"text": "TEXT",
|
||||||
|
"date": "DATE",
|
||||||
|
"time": "TIME",
|
||||||
|
"timestamp": "DATETIME",
|
||||||
|
"timestamptz": "DATETIME",
|
||||||
|
"uuid": "CHAR(36)",
|
||||||
|
"json": "JSON",
|
||||||
|
"jsonb": "JSON",
|
||||||
|
"bytea": "BLOB",
|
||||||
|
}
|
||||||
|
|
||||||
|
// MariaDBTypeSynonyms maps MariaDB/MySQL type aliases to their canonical MariaDB name.
|
||||||
|
var MariaDBTypeSynonyms = map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"int2": "smallint",
|
||||||
|
"int4": "int",
|
||||||
|
"int8": "bigint",
|
||||||
|
"double precision": "double",
|
||||||
|
"character": "char",
|
||||||
|
"dec": "decimal",
|
||||||
|
"fixed": "decimal",
|
||||||
|
"numeric": "decimal",
|
||||||
|
"boolean": "tinyint",
|
||||||
|
"bool": "tinyint",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeMariaDBType maps a MariaDB/MySQL base type (no dimension parameters)
|
||||||
|
// to its canonical MariaDB form. Unknown types are returned as-is (lowercased).
|
||||||
|
func NormalizeMariaDBType(baseType string) string {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(baseType))
|
||||||
|
if canonical, ok := MariaDBTypeSynonyms[lower]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
return lower
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertMariaDBToCanonical converts a MariaDB/MySQL type name to the canonical type.
|
||||||
|
// Strips dimension parameters and normalizes aliases. Defaults to "string".
|
||||||
|
func ConvertMariaDBToCanonical(mariadbType string) string {
|
||||||
|
base := strings.ToLower(strings.TrimSpace(mariadbType))
|
||||||
|
if idx := strings.Index(base, "("); idx >= 0 {
|
||||||
|
base = strings.TrimSpace(base[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
if canonical, ok := MariaDBToCanonicalTypes[base]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefix match for composite types (e.g., "unsigned bigint")
|
||||||
|
for key, canonical := range MariaDBToCanonicalTypes {
|
||||||
|
if strings.HasPrefix(base, key) {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "string"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCanonicalToMariaDB converts a canonical type to a MariaDB/MySQL type.
|
||||||
|
// Defaults to VARCHAR(255) for unrecognised types.
|
||||||
|
func ConvertCanonicalToMariaDB(canonicalType string) string {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(canonicalType))
|
||||||
|
if idx := strings.Index(lower, "("); idx >= 0 {
|
||||||
|
lower = strings.TrimSpace(lower[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
if mariadbType, ok := CanonicalToMariaDBTypes[lower]; ok {
|
||||||
|
return mariadbType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefix fallback
|
||||||
|
for canonical, mariadb := range CanonicalToMariaDBTypes {
|
||||||
|
if strings.HasPrefix(lower, canonical) {
|
||||||
|
return mariadb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "VARCHAR(255)"
|
||||||
|
}
|
||||||
+126
-1
@@ -5,9 +5,11 @@ package merge
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MergeResult represents the result of a merge operation
|
// MergeResult represents the result of a merge operation
|
||||||
@@ -22,6 +24,16 @@ type MergeResult struct {
|
|||||||
EnumsAdded int
|
EnumsAdded int
|
||||||
ViewsAdded int
|
ViewsAdded int
|
||||||
SequencesAdded int
|
SequencesAdded int
|
||||||
|
TypeConflicts []ColumnTypeConflict
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnTypeConflict describes a column that exists in both schemas but with incompatible types.
|
||||||
|
type ColumnTypeConflict struct {
|
||||||
|
Schema string
|
||||||
|
Table string
|
||||||
|
Column string
|
||||||
|
TargetType string
|
||||||
|
SourceType string
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeOptions contains options for merge operations
|
// MergeOptions contains options for merge operations
|
||||||
@@ -146,11 +158,19 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
|||||||
|
|
||||||
// Merge columns
|
// Merge columns
|
||||||
for colName, srcCol := range srcTable.Columns {
|
for colName, srcCol := range srcTable.Columns {
|
||||||
if _, exists := existingColumns[colName]; !exists {
|
if tgtCol, exists := existingColumns[colName]; !exists {
|
||||||
// Column doesn't exist, add it
|
// Column doesn't exist, add it
|
||||||
newCol := cloneColumn(srcCol)
|
newCol := cloneColumn(srcCol)
|
||||||
table.Columns[colName] = newCol
|
table.Columns[colName] = newCol
|
||||||
r.ColumnsAdded++
|
r.ColumnsAdded++
|
||||||
|
} else if columnTypeConflict(tgtCol, srcCol) {
|
||||||
|
r.TypeConflicts = append(r.TypeConflicts, ColumnTypeConflict{
|
||||||
|
Schema: firstNonEmpty(table.Schema, srcTable.Schema, srcCol.Schema),
|
||||||
|
Table: firstNonEmpty(table.Name, srcTable.Name, srcCol.Table),
|
||||||
|
Column: firstNonEmpty(tgtCol.Name, srcCol.Name, colName),
|
||||||
|
TargetType: describeColumnType(tgtCol),
|
||||||
|
SourceType: describeColumnType(srcCol),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -426,6 +446,78 @@ func cloneColumn(col *models.Column) *models.Column {
|
|||||||
return newCol
|
return newCol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func columnTypeConflict(target, source *models.Column) bool {
|
||||||
|
if target == nil || source == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
tType, tLen, tPrec, tScale := extractTypeParts(target)
|
||||||
|
sType, sLen, sPrec, sScale := extractTypeParts(source)
|
||||||
|
|
||||||
|
return tType != sType || tLen != sLen || tPrec != sPrec || tScale != sScale
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTypeParts returns the canonical base type and dimensions for a column,
|
||||||
|
// handling the case where dimensions are embedded in the type string (e.g. "char(2)")
|
||||||
|
// rather than stored in the separate Length/Precision/Scale fields.
|
||||||
|
func extractTypeParts(col *models.Column) (baseType string, length, precision, scale int) {
|
||||||
|
typeName := strings.ToLower(strings.TrimSpace(col.Type))
|
||||||
|
length, precision, scale = col.Length, col.Precision, col.Scale
|
||||||
|
|
||||||
|
if idx := strings.Index(typeName, "("); idx >= 0 {
|
||||||
|
inner := strings.TrimRight(strings.TrimSpace(typeName[idx+1:]), ")")
|
||||||
|
typeName = strings.TrimSpace(typeName[:idx])
|
||||||
|
parts := strings.Split(inner, ",")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
if p, err := strconv.Atoi(strings.TrimSpace(parts[0])); err == nil && p > 0 && precision == 0 {
|
||||||
|
precision = p
|
||||||
|
}
|
||||||
|
if s, err := strconv.Atoi(strings.TrimSpace(parts[1])); err == nil && s > 0 && scale == 0 {
|
||||||
|
scale = s
|
||||||
|
}
|
||||||
|
} else if len(parts) == 1 {
|
||||||
|
if l, err := strconv.Atoi(strings.TrimSpace(parts[0])); err == nil && l > 0 && length == 0 && precision == 0 {
|
||||||
|
length = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typeName = pgsql.NormalizePGType(typeName)
|
||||||
|
|
||||||
|
return typeName, length, precision, scale
|
||||||
|
}
|
||||||
|
|
||||||
|
func describeColumnType(col *models.Column) string {
|
||||||
|
if col == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
typeName := strings.TrimSpace(col.Type)
|
||||||
|
if typeName == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case col.Precision > 0 && col.Scale > 0:
|
||||||
|
return fmt.Sprintf("%s(%d,%d)", typeName, col.Precision, col.Scale)
|
||||||
|
case col.Precision > 0:
|
||||||
|
return fmt.Sprintf("%s(%d)", typeName, col.Precision)
|
||||||
|
case col.Length > 0:
|
||||||
|
return fmt.Sprintf("%s(%d)", typeName, col.Length)
|
||||||
|
default:
|
||||||
|
return typeName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmpty(values ...string) string {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
|
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
|
||||||
if constraint == nil {
|
if constraint == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -609,6 +701,7 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||||
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
|
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
|
||||||
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
|
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
|
||||||
|
fmt.Sprintf("Type conflicts: %d", len(result.TypeConflicts)),
|
||||||
}
|
}
|
||||||
|
|
||||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||||
@@ -625,3 +718,35 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
|
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetColumnTypeConflictSummary returns a short, human-readable conflict summary.
|
||||||
|
func GetColumnTypeConflictSummary(result *MergeResult, limit int) string {
|
||||||
|
if result == nil || len(result.TypeConflicts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = len(result.TypeConflicts)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := make([]string, 0, min(limit, len(result.TypeConflicts))+1)
|
||||||
|
lines = append(lines, "column type conflicts detected:")
|
||||||
|
for i, conflict := range result.TypeConflicts {
|
||||||
|
if i >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf(" - %s.%s.%s: target=%s source=%s",
|
||||||
|
conflict.Schema, conflict.Table, conflict.Column, conflict.TargetType, conflict.SourceType))
|
||||||
|
}
|
||||||
|
if len(result.TypeConflicts) > limit {
|
||||||
|
lines = append(lines, fmt.Sprintf(" ... and %d more", len(result.TypeConflicts)-limit))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package merge
|
package merge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
@@ -140,6 +141,61 @@ func TestMergeColumns_NewColumn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMergeColumns_TypeConflictIsDetected(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "varchar", Length: 255},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
|
||||||
|
if len(result.TypeConflicts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 type conflict, got %d", len(result.TypeConflicts))
|
||||||
|
}
|
||||||
|
conflict := result.TypeConflicts[0]
|
||||||
|
if conflict.Schema != "public" || conflict.Table != "users" || conflict.Column != "email" {
|
||||||
|
t.Fatalf("Unexpected conflict location: %+v", conflict)
|
||||||
|
}
|
||||||
|
if conflict.TargetType != "varchar(255)" {
|
||||||
|
t.Fatalf("Expected target type varchar(255), got %q", conflict.TargetType)
|
||||||
|
}
|
||||||
|
if conflict.SourceType != "text" {
|
||||||
|
t.Fatalf("Expected source type text, got %q", conflict.SourceType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := target.Schemas[0].Tables[0].Columns["email"].Type; got != "varchar" {
|
||||||
|
t.Fatalf("Expected target column type to remain unchanged, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
||||||
target := &models.Database{
|
target := &models.Database{
|
||||||
Schemas: []*models.Schema{
|
Schemas: []*models.Schema{
|
||||||
@@ -509,6 +565,9 @@ func TestGetMergeSummary(t *testing.T) {
|
|||||||
ConstraintsAdded: 3,
|
ConstraintsAdded: 3,
|
||||||
IndexesAdded: 2,
|
IndexesAdded: 2,
|
||||||
ViewsAdded: 1,
|
ViewsAdded: 1,
|
||||||
|
TypeConflicts: []ColumnTypeConflict{
|
||||||
|
{Schema: "public", Table: "users", Column: "email", TargetType: "varchar(255)", SourceType: "text"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
summary := GetMergeSummary(result)
|
summary := GetMergeSummary(result)
|
||||||
@@ -518,6 +577,9 @@ func TestGetMergeSummary(t *testing.T) {
|
|||||||
if len(summary) < 50 {
|
if len(summary) < 50 {
|
||||||
t.Errorf("Summary seems too short: %s", summary)
|
t.Errorf("Summary seems too short: %s", summary)
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(summary, "Type conflicts: 1") {
|
||||||
|
t.Errorf("Expected type conflict count in summary, got: %s", summary)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetMergeSummary_Nil(t *testing.T) {
|
func TestGetMergeSummary_Nil(t *testing.T) {
|
||||||
|
|||||||
+76
-32
@@ -2,32 +2,73 @@ package mssql
|
|||||||
|
|
||||||
import "strings"
|
import "strings"
|
||||||
|
|
||||||
// CanonicalToMSSQLTypes maps canonical types to MSSQL types
|
// CanonicalToMSSQLTypes maps canonical types to MSSQL types.
|
||||||
|
// Accepts both Go canonical names ("int", "string") and SQL canonical names
|
||||||
|
// ("integer", "varchar") so the writer handles input from any reader.
|
||||||
var CanonicalToMSSQLTypes = map[string]string{
|
var CanonicalToMSSQLTypes = map[string]string{
|
||||||
|
// Boolean — Go and SQL canonical
|
||||||
"bool": "BIT",
|
"bool": "BIT",
|
||||||
|
"boolean": "BIT",
|
||||||
|
// Integer — Go canonical
|
||||||
"int8": "TINYINT",
|
"int8": "TINYINT",
|
||||||
"int16": "SMALLINT",
|
"int16": "SMALLINT",
|
||||||
"int": "INT",
|
"int": "INT",
|
||||||
"int32": "INT",
|
"int32": "INT",
|
||||||
"int64": "BIGINT",
|
"int64": "BIGINT",
|
||||||
"uint": "BIGINT",
|
"uint": "BIGINT",
|
||||||
"uint8": "SMALLINT",
|
"uint8": "TINYINT",
|
||||||
"uint16": "INT",
|
"uint16": "SMALLINT",
|
||||||
"uint32": "BIGINT",
|
"uint32": "BIGINT",
|
||||||
"uint64": "BIGINT",
|
"uint64": "BIGINT",
|
||||||
|
// Integer — SQL canonical (serial types map to base integer; IDENTITY is set via AutoIncrement)
|
||||||
|
"integer": "INT",
|
||||||
|
"smallint": "SMALLINT",
|
||||||
|
"bigint": "BIGINT",
|
||||||
|
"tinyint": "TINYINT",
|
||||||
|
"serial": "INT",
|
||||||
|
"smallserial": "SMALLINT",
|
||||||
|
"bigserial": "BIGINT",
|
||||||
|
// Float — Go canonical
|
||||||
"float32": "REAL",
|
"float32": "REAL",
|
||||||
"float64": "FLOAT",
|
"float64": "FLOAT",
|
||||||
|
// Float — SQL canonical
|
||||||
|
"real": "REAL",
|
||||||
|
"double precision": "FLOAT",
|
||||||
|
"double": "FLOAT",
|
||||||
|
// Decimal/numeric
|
||||||
"decimal": "NUMERIC",
|
"decimal": "NUMERIC",
|
||||||
|
"numeric": "NUMERIC",
|
||||||
|
"money": "MONEY",
|
||||||
|
// String — Go canonical
|
||||||
"string": "NVARCHAR(255)",
|
"string": "NVARCHAR(255)",
|
||||||
"text": "NVARCHAR(MAX)",
|
"text": "NVARCHAR(MAX)",
|
||||||
|
// String — SQL canonical
|
||||||
|
"varchar": "NVARCHAR(255)",
|
||||||
|
"char": "NCHAR",
|
||||||
|
"nvarchar": "NVARCHAR(255)",
|
||||||
|
"nchar": "NCHAR",
|
||||||
|
"citext": "NVARCHAR(MAX)",
|
||||||
|
// Date/time
|
||||||
"date": "DATE",
|
"date": "DATE",
|
||||||
"time": "TIME",
|
"time": "TIME",
|
||||||
|
"timetz": "DATETIMEOFFSET",
|
||||||
"timestamp": "DATETIME2",
|
"timestamp": "DATETIME2",
|
||||||
"timestamptz": "DATETIMEOFFSET",
|
"timestamptz": "DATETIMEOFFSET",
|
||||||
|
"datetime": "DATETIME2",
|
||||||
|
"interval": "NVARCHAR(50)",
|
||||||
|
// UUID
|
||||||
"uuid": "UNIQUEIDENTIFIER",
|
"uuid": "UNIQUEIDENTIFIER",
|
||||||
|
// JSON — MSSQL has no native JSON type; stored as NVARCHAR(MAX)
|
||||||
"json": "NVARCHAR(MAX)",
|
"json": "NVARCHAR(MAX)",
|
||||||
"jsonb": "NVARCHAR(MAX)",
|
"jsonb": "NVARCHAR(MAX)",
|
||||||
|
// Binary
|
||||||
"bytea": "VARBINARY(MAX)",
|
"bytea": "VARBINARY(MAX)",
|
||||||
|
"blob": "VARBINARY(MAX)",
|
||||||
|
// Network/geo types — no MSSQL native equivalent
|
||||||
|
"xml": "XML",
|
||||||
|
"inet": "NVARCHAR(45)",
|
||||||
|
"cidr": "NVARCHAR(43)",
|
||||||
|
"macaddr": "NVARCHAR(17)",
|
||||||
}
|
}
|
||||||
|
|
||||||
// MSSQLToCanonicalTypes maps MSSQL types to canonical types
|
// MSSQLToCanonicalTypes maps MSSQL types to canonical types
|
||||||
@@ -68,47 +109,50 @@ var MSSQLToCanonicalTypes = map[string]string{
|
|||||||
"geometry": "string",
|
"geometry": "string",
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCanonicalToMSSQL converts a canonical type to MSSQL type
|
// MSSQLTypeSynonyms maps MSSQL type aliases to their canonical MSSQL name.
|
||||||
|
var MSSQLTypeSynonyms = map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"dec": "decimal",
|
||||||
|
"float(n)": "float",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeMSSQLType maps an MSSQL base type (no dimension parameters) to its
|
||||||
|
// canonical MSSQL form. Unknown types are returned as-is (lowercased).
|
||||||
|
func NormalizeMSSQLType(baseType string) string {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(baseType))
|
||||||
|
if canonical, ok := MSSQLTypeSynonyms[lower]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
return lower
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCanonicalToMSSQL converts a canonical type (Go or SQL) to an MSSQL type.
|
||||||
|
// Strips dimension parameters before lookup. Defaults to NVARCHAR(255) for unknown types.
|
||||||
func ConvertCanonicalToMSSQL(canonicalType string) string {
|
func ConvertCanonicalToMSSQL(canonicalType string) string {
|
||||||
// Check direct mapping
|
base := strings.ToLower(strings.TrimSpace(canonicalType))
|
||||||
if mssqlType, exists := CanonicalToMSSQLTypes[strings.ToLower(canonicalType)]; exists {
|
if idx := strings.Index(base, "("); idx >= 0 {
|
||||||
|
base = strings.TrimSpace(base[:idx])
|
||||||
|
}
|
||||||
|
base = strings.TrimSuffix(base, "[]")
|
||||||
|
|
||||||
|
if mssqlType, exists := CanonicalToMSSQLTypes[base]; exists {
|
||||||
return mssqlType
|
return mssqlType
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to find by prefix
|
|
||||||
lowerType := strings.ToLower(canonicalType)
|
|
||||||
for canonical, mssql := range CanonicalToMSSQLTypes {
|
|
||||||
if strings.HasPrefix(lowerType, canonical) {
|
|
||||||
return mssql
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to NVARCHAR
|
|
||||||
return "NVARCHAR(255)"
|
return "NVARCHAR(255)"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertMSSQLToCanonical converts an MSSQL type to canonical type
|
// ConvertMSSQLToCanonical converts an MSSQL type to the canonical type.
|
||||||
|
// Strips dimension parameters before lookup. Defaults to "string" for unknown types.
|
||||||
func ConvertMSSQLToCanonical(mssqlType string) string {
|
func ConvertMSSQLToCanonical(mssqlType string) string {
|
||||||
// Extract base type (remove parentheses and parameters)
|
base := strings.ToLower(strings.TrimSpace(mssqlType))
|
||||||
baseType := mssqlType
|
if idx := strings.Index(base, "("); idx >= 0 {
|
||||||
if idx := strings.Index(baseType, "("); idx != -1 {
|
base = strings.TrimSpace(base[:idx])
|
||||||
baseType = baseType[:idx]
|
|
||||||
}
|
}
|
||||||
baseType = strings.TrimSpace(baseType)
|
|
||||||
|
|
||||||
// Check direct mapping
|
if canonicalType, exists := MSSQLToCanonicalTypes[base]; exists {
|
||||||
if canonicalType, exists := MSSQLToCanonicalTypes[strings.ToLower(baseType)]; exists {
|
|
||||||
return canonicalType
|
return canonicalType
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to find by prefix
|
|
||||||
lowerType := strings.ToLower(baseType)
|
|
||||||
for mssql, canonical := range MSSQLToCanonicalTypes {
|
|
||||||
if strings.HasPrefix(lowerType, mssql) {
|
|
||||||
return canonical
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to string
|
|
||||||
return "string"
|
return "string"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ var GoToStdTypes = map[string]string{
|
|||||||
"sqldate": "date",
|
"sqldate": "date",
|
||||||
"sqltime": "time",
|
"sqltime": "time",
|
||||||
"sqltimestamp": "timestamp",
|
"sqltimestamp": "timestamp",
|
||||||
|
"time.Time": "timestamp",
|
||||||
}
|
}
|
||||||
|
|
||||||
var GoToPGSQLTypes = map[string]string{
|
var GoToPGSQLTypes = map[string]string{
|
||||||
@@ -90,6 +91,7 @@ var GoToPGSQLTypes = map[string]string{
|
|||||||
"sqldate": "date",
|
"sqldate": "date",
|
||||||
"sqltime": "time",
|
"sqltime": "time",
|
||||||
"sqltimestamp": "timestamp",
|
"sqltimestamp": "timestamp",
|
||||||
|
"time.Time": "timestamp",
|
||||||
"citext": "citext",
|
"citext": "citext",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,6 +137,62 @@ func ConvertSQLType(anytype string) string {
|
|||||||
return anytype
|
return anytype
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PGTypeCanonical maps PostgreSQL type aliases and synonyms to their canonical base name.
|
||||||
|
// Input should be a base type (no dimension parameters, lowercase).
|
||||||
|
var PGTypeCanonical = map[string]string{
|
||||||
|
// integer aliases
|
||||||
|
"int": "integer",
|
||||||
|
"int4": "integer",
|
||||||
|
"int2": "smallint",
|
||||||
|
"int8": "bigint",
|
||||||
|
// float aliases
|
||||||
|
"float4": "real",
|
||||||
|
"float8": "double precision",
|
||||||
|
// bool alias
|
||||||
|
"bool": "boolean",
|
||||||
|
// char aliases
|
||||||
|
"character": "char",
|
||||||
|
"character varying": "varchar",
|
||||||
|
"bpchar": "char",
|
||||||
|
// timestamp aliases
|
||||||
|
"timestamp without time zone": "timestamp",
|
||||||
|
"timestamp with time zone": "timestamptz",
|
||||||
|
// time aliases
|
||||||
|
"time without time zone": "time",
|
||||||
|
"time with time zone": "timetz",
|
||||||
|
// decimal alias
|
||||||
|
"decimal": "numeric",
|
||||||
|
}
|
||||||
|
|
||||||
|
// knownPGBaseTypes is the set of canonical PostgreSQL base types (no aliases).
|
||||||
|
var knownPGBaseTypes = map[string]struct{}{
|
||||||
|
"integer": {}, "bigint": {}, "smallint": {},
|
||||||
|
"serial": {}, "bigserial": {}, "smallserial": {},
|
||||||
|
"numeric": {}, "real": {}, "double precision": {}, "money": {},
|
||||||
|
"varchar": {}, "char": {}, "text": {}, "citext": {},
|
||||||
|
"boolean": {},
|
||||||
|
"date": {}, "time": {}, "timetz": {}, "timestamp": {}, "timestamptz": {}, "interval": {},
|
||||||
|
"uuid": {}, "json": {}, "jsonb": {}, "bytea": {},
|
||||||
|
"inet": {}, "cidr": {}, "macaddr": {}, "xml": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizePGType maps a PostgreSQL base type (no dimension parameters) to its
|
||||||
|
// canonical form. Unknown types are returned as-is (lowercased).
|
||||||
|
func NormalizePGType(baseType string) string {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(baseType))
|
||||||
|
if canonical, ok := PGTypeCanonical[lower]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
return lower
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsKnownPGBaseType reports whether the given name (after NormalizePGType) is a
|
||||||
|
// recognized built-in PostgreSQL type. Custom types (e.g. vector, postgis) return false.
|
||||||
|
func IsKnownPGBaseType(baseType string) bool {
|
||||||
|
_, ok := knownPGBaseTypes[strings.ToLower(strings.TrimSpace(baseType))]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func IsGoType(pTypeName string) bool {
|
func IsGoType(pTypeName string) bool {
|
||||||
for k := range GoToStdTypes {
|
for k := range GoToStdTypes {
|
||||||
if strings.EqualFold(pTypeName, k) {
|
if strings.EqualFold(pTypeName, k) {
|
||||||
|
|||||||
@@ -145,6 +145,24 @@ var postgresTypeAliases = map[string]string{
|
|||||||
"bool": "boolean",
|
"bool": "boolean",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var postgresEquivalentBaseTypes = map[string]string{
|
||||||
|
"character varying": "varchar",
|
||||||
|
"character": "char",
|
||||||
|
"timestamp without time zone": "timestamp",
|
||||||
|
"timestamp with time zone": "timestamptz",
|
||||||
|
"time without time zone": "time",
|
||||||
|
"time with time zone": "timetz",
|
||||||
|
}
|
||||||
|
|
||||||
|
var postgresEquivalentBaseTypeVariants = map[string][]string{
|
||||||
|
"varchar": {"varchar", "character varying"},
|
||||||
|
"char": {"char", "character"},
|
||||||
|
"timestamp": {"timestamp", "timestamp without time zone"},
|
||||||
|
"timestamptz": {"timestamptz", "timestamp with time zone"},
|
||||||
|
"time": {"time", "time without time zone"},
|
||||||
|
"timetz": {"timetz", "time with time zone"},
|
||||||
|
}
|
||||||
|
|
||||||
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
|
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
|
||||||
func GetPostgresBaseTypes() []string {
|
func GetPostgresBaseTypes() []string {
|
||||||
result := make([]string, 0, len(postgresBaseTypes))
|
result := make([]string, 0, len(postgresBaseTypes))
|
||||||
@@ -212,6 +230,86 @@ func CanonicalizeBaseType(baseType string) string {
|
|||||||
return base
|
return base
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EquivalentBaseType resolves broader SQL-equivalent spellings to a common comparable form.
|
||||||
|
func EquivalentBaseType(baseType string) string {
|
||||||
|
base := CanonicalizeBaseType(baseType)
|
||||||
|
if equivalent, ok := postgresEquivalentBaseTypes[base]; ok {
|
||||||
|
return equivalent
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeEquivalentSQLType returns a normalized SQL type string suitable for equality checks.
|
||||||
|
// Equivalent spellings such as "character varying(255)" and "varchar(255)" normalize identically.
|
||||||
|
func NormalizeEquivalentSQLType(sqlType string) string {
|
||||||
|
t := normalizeTypeToken(sqlType)
|
||||||
|
if t == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayDepth := 0
|
||||||
|
for strings.HasSuffix(t, "[]") {
|
||||||
|
arrayDepth++
|
||||||
|
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := ""
|
||||||
|
if idx := strings.Index(t, "("); idx >= 0 {
|
||||||
|
modifier = strings.TrimSpace(t[idx:])
|
||||||
|
t = strings.TrimSpace(t[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
base := EquivalentBaseType(t)
|
||||||
|
normalized := base + modifier
|
||||||
|
for i := 0; i < arrayDepth; i++ {
|
||||||
|
normalized += "[]"
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// EquivalentSQLTypeVariants returns equivalent PostgreSQL spellings for a SQL type.
|
||||||
|
// Examples:
|
||||||
|
// - varchar(255) -> ["varchar(255)", "character varying(255)"]
|
||||||
|
// - timestamptz -> ["timestamptz", "timestamp with time zone"]
|
||||||
|
func EquivalentSQLTypeVariants(sqlType string) []string {
|
||||||
|
t := normalizeTypeToken(sqlType)
|
||||||
|
if t == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayDepth := 0
|
||||||
|
for strings.HasSuffix(t, "[]") {
|
||||||
|
arrayDepth++
|
||||||
|
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := ""
|
||||||
|
if idx := strings.Index(t, "("); idx >= 0 {
|
||||||
|
modifier = strings.TrimSpace(t[idx:])
|
||||||
|
t = strings.TrimSpace(t[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
base := EquivalentBaseType(t)
|
||||||
|
bases := postgresEquivalentBaseTypeVariants[base]
|
||||||
|
if len(bases) == 0 {
|
||||||
|
bases = []string{base}
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool, len(bases))
|
||||||
|
result := make([]string, 0, len(bases))
|
||||||
|
for _, variantBase := range bases {
|
||||||
|
variant := variantBase + modifier
|
||||||
|
for i := 0; i < arrayDepth; i++ {
|
||||||
|
variant += "[]"
|
||||||
|
}
|
||||||
|
if !seen[variant] {
|
||||||
|
seen[variant] = true
|
||||||
|
result = append(result, variant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
|
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
|
||||||
func IsKnownPostgresType(sqlType string) bool {
|
func IsKnownPostgresType(sqlType string) bool {
|
||||||
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
|
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
|
||||||
|
|||||||
@@ -97,3 +97,51 @@ func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeEquivalentSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{input: "character varying(255)", want: "varchar(255)"},
|
||||||
|
{input: "varchar(255)", want: "varchar(255)"},
|
||||||
|
{input: "timestamp with time zone", want: "timestamptz"},
|
||||||
|
{input: "timestamptz", want: "timestamptz"},
|
||||||
|
{input: "time without time zone", want: "time"},
|
||||||
|
{input: "character varying(255)[]", want: "varchar(255)[]"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := NormalizeEquivalentSQLType(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("NormalizeEquivalentSQLType(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEquivalentSQLTypeVariants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{input: "character varying(255)", want: []string{"varchar(255)", "character varying(255)"}},
|
||||||
|
{input: "timestamptz", want: []string{"timestamptz", "timestamp with time zone"}},
|
||||||
|
{input: "text[]", want: []string{"text[]"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := EquivalentSQLTypeVariants(tt.input)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Fatalf("EquivalentSQLTypeVariants(%q) len = %d, want %d (%v)", tt.input, len(got), len(tt.want), got)
|
||||||
|
}
|
||||||
|
for i := range tt.want {
|
||||||
|
if got[i] != tt.want[i] {
|
||||||
|
t.Fatalf("EquivalentSQLTypeVariants(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -270,8 +270,16 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if numPrecision != nil {
|
if numPrecision != nil {
|
||||||
|
// For integer and serial types, numeric_precision is a bit-width (32, 64, 16)
|
||||||
|
// not a user-visible column parameter. Only store precision for types where
|
||||||
|
// it represents actual decimal/scale precision (numeric, decimal, float).
|
||||||
|
switch column.Type {
|
||||||
|
case "integer", "bigint", "smallint", "serial", "bigserial", "smallserial":
|
||||||
|
// skip — bit-width, not a column parameter
|
||||||
|
default:
|
||||||
column.Precision = *numPrecision
|
column.Precision = *numPrecision
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if numScale != nil {
|
if numScale != nil {
|
||||||
column.Scale = *numScale
|
column.Scale = *numScale
|
||||||
|
|||||||
+28
-62
@@ -259,12 +259,14 @@ func (r *Reader) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapDataType maps PostgreSQL data types while preserving exact type text when available.
|
// mapDataType maps a PostgreSQL data type to its canonical RelSpec name.
|
||||||
|
// For known built-in types, dimensions are stripped from the type string (they are
|
||||||
|
// stored separately in column.Length/Precision/Scale). For custom types (e.g.
|
||||||
|
// vector(1536), postgis geometries), the full formatted type is preserved.
|
||||||
func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
|
func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
|
||||||
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
|
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
|
||||||
|
|
||||||
// If the column has a nextval default, it's likely a serial type
|
// Detect serial types from nextval defaults before anything else.
|
||||||
// Map to the appropriate serial type instead of the base integer type
|
|
||||||
if hasNextval {
|
if hasNextval {
|
||||||
switch normalizedPGType {
|
switch normalizedPGType {
|
||||||
case "integer", "int", "int4":
|
case "integer", "int", "int4":
|
||||||
@@ -276,73 +278,38 @@ func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefer the database-provided formatted type; this preserves arrays/custom
|
|
||||||
// types/modifiers like text[], vector(1536), numeric(10,2), etc.
|
|
||||||
if strings.TrimSpace(formattedType) != "" {
|
|
||||||
return formattedType
|
|
||||||
}
|
|
||||||
|
|
||||||
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
|
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
|
||||||
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
||||||
return udtName[1:] + "[]"
|
return udtName[1:] + "[]"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map common PostgreSQL types
|
// Use the database-formatted type when available. For known built-in types, strip
|
||||||
typeMap := map[string]string{
|
// embedded dimensions (they are stored in column.Length/Precision/Scale separately).
|
||||||
"integer": "integer",
|
// For unknown/custom types, keep the full formatted string (e.g. vector(1536)).
|
||||||
"bigint": "bigint",
|
if strings.TrimSpace(formattedType) != "" {
|
||||||
"smallint": "smallint",
|
lower := strings.ToLower(strings.TrimSpace(formattedType))
|
||||||
"int": "integer",
|
isArray := strings.HasSuffix(lower, "[]")
|
||||||
"int2": "smallint",
|
base := strings.TrimSuffix(lower, "[]")
|
||||||
"int4": "integer",
|
if idx := strings.Index(base, "("); idx >= 0 {
|
||||||
"int8": "bigint",
|
base = strings.TrimSpace(base[:idx])
|
||||||
"serial": "serial",
|
}
|
||||||
"bigserial": "bigserial",
|
canonical := pgsql.NormalizePGType(base)
|
||||||
"smallserial": "smallserial",
|
if pgsql.IsKnownPGBaseType(canonical) {
|
||||||
"numeric": "numeric",
|
if isArray {
|
||||||
"decimal": "decimal",
|
return canonical + "[]"
|
||||||
"real": "real",
|
}
|
||||||
"double precision": "double precision",
|
return canonical
|
||||||
"float4": "real",
|
}
|
||||||
"float8": "double precision",
|
return formattedType
|
||||||
"money": "money",
|
|
||||||
"character varying": "varchar",
|
|
||||||
"varchar": "varchar",
|
|
||||||
"character": "char",
|
|
||||||
"char": "char",
|
|
||||||
"text": "text",
|
|
||||||
"boolean": "boolean",
|
|
||||||
"bool": "boolean",
|
|
||||||
"date": "date",
|
|
||||||
"time": "time",
|
|
||||||
"time without time zone": "time",
|
|
||||||
"time with time zone": "timetz",
|
|
||||||
"timestamp": "timestamp",
|
|
||||||
"timestamp without time zone": "timestamp",
|
|
||||||
"timestamp with time zone": "timestamptz",
|
|
||||||
"timestamptz": "timestamptz",
|
|
||||||
"interval": "interval",
|
|
||||||
"uuid": "uuid",
|
|
||||||
"json": "json",
|
|
||||||
"jsonb": "jsonb",
|
|
||||||
"bytea": "bytea",
|
|
||||||
"inet": "inet",
|
|
||||||
"cidr": "cidr",
|
|
||||||
"macaddr": "macaddr",
|
|
||||||
"xml": "xml",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try mapped type first
|
// Fall back to normalizing the information_schema type name directly.
|
||||||
if mapped, exists := typeMap[normalizedPGType]; exists {
|
canonical := pgsql.NormalizePGType(normalizedPGType)
|
||||||
return mapped
|
if pgsql.IsKnownPGBaseType(canonical) {
|
||||||
|
return canonical
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use pgsql utilities if available
|
// Return UDT name for custom types.
|
||||||
if pgsql.ValidSQLType(pgType) {
|
|
||||||
return pgsql.GetSQLType(pgType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return UDT name for custom types (including array fallback when needed)
|
|
||||||
if udtName != "" {
|
if udtName != "" {
|
||||||
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
||||||
return udtName[1:] + "[]"
|
return udtName[1:] + "[]"
|
||||||
@@ -350,7 +317,6 @@ func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval b
|
|||||||
return udtName
|
return udtName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to the original type
|
|
||||||
return pgType
|
return pgType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ func TestMapDataType(t *testing.T) {
|
|||||||
{"unknown_type", "custom", "", "custom"}, // Should return UDT name
|
{"unknown_type", "custom", "", "custom"}, // Should return UDT name
|
||||||
{"ARRAY", "_text", "", "text[]"},
|
{"ARRAY", "_text", "", "text[]"},
|
||||||
{"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"},
|
{"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"},
|
||||||
{"character varying", "varchar", "character varying(255)", "character varying(255)"},
|
{"character varying", "varchar", "character varying(255)", "varchar"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||||
|
sqlitepkg "git.warky.dev/wdevs/relspecgo/pkg/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reader implements the readers.Reader interface for SQLite databases
|
// Reader implements the readers.Reader interface for SQLite databases
|
||||||
@@ -183,59 +184,9 @@ func (r *Reader) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapDataType maps SQLite data types to canonical types
|
// mapDataType maps SQLite data types to canonical types.
|
||||||
func (r *Reader) mapDataType(sqliteType string) string {
|
func (r *Reader) mapDataType(sqliteType string) string {
|
||||||
// SQLite has a flexible type system, but we map common types
|
return sqlitepkg.ConvertSQLiteToCanonical(sqliteType)
|
||||||
typeMap := map[string]string{
|
|
||||||
"INTEGER": "int",
|
|
||||||
"INT": "int",
|
|
||||||
"TINYINT": "int8",
|
|
||||||
"SMALLINT": "int16",
|
|
||||||
"MEDIUMINT": "int",
|
|
||||||
"BIGINT": "int64",
|
|
||||||
"UNSIGNED BIG INT": "uint64",
|
|
||||||
"INT2": "int16",
|
|
||||||
"INT8": "int64",
|
|
||||||
"REAL": "float64",
|
|
||||||
"DOUBLE": "float64",
|
|
||||||
"DOUBLE PRECISION": "float64",
|
|
||||||
"FLOAT": "float32",
|
|
||||||
"NUMERIC": "decimal",
|
|
||||||
"DECIMAL": "decimal",
|
|
||||||
"BOOLEAN": "bool",
|
|
||||||
"BOOL": "bool",
|
|
||||||
"DATE": "date",
|
|
||||||
"DATETIME": "timestamp",
|
|
||||||
"TIMESTAMP": "timestamp",
|
|
||||||
"TEXT": "string",
|
|
||||||
"VARCHAR": "string",
|
|
||||||
"CHAR": "string",
|
|
||||||
"CHARACTER": "string",
|
|
||||||
"VARYING CHARACTER": "string",
|
|
||||||
"NCHAR": "string",
|
|
||||||
"NVARCHAR": "string",
|
|
||||||
"CLOB": "text",
|
|
||||||
"BLOB": "bytea",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try exact match first
|
|
||||||
if mapped, exists := typeMap[sqliteType]; exists {
|
|
||||||
return mapped
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try case-insensitive match for common types
|
|
||||||
sqliteTypeUpper := sqliteType
|
|
||||||
if len(sqliteType) > 0 {
|
|
||||||
// Extract base type (e.g., "VARCHAR(255)" -> "VARCHAR")
|
|
||||||
for baseType := range typeMap {
|
|
||||||
if len(sqliteTypeUpper) >= len(baseType) && sqliteTypeUpper[:len(baseType)] == baseType {
|
|
||||||
return typeMap[baseType]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to string for unknown types
|
|
||||||
return "string"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// deriveRelationship creates a relationship from a foreign key constraint
|
// deriveRelationship creates a relationship from a foreign key constraint
|
||||||
|
|||||||
@@ -0,0 +1,152 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// SQLiteToCanonicalTypes maps SQLite type names to canonical types.
|
||||||
|
// SQLite has type affinity rules; this maps common type names including
|
||||||
|
// MySQL/PostgreSQL types that users write in SQLite schemas.
|
||||||
|
var SQLiteToCanonicalTypes = map[string]string{
|
||||||
|
// Integer affinity
|
||||||
|
"integer": "int",
|
||||||
|
"int": "int",
|
||||||
|
"tinyint": "int8",
|
||||||
|
"smallint": "int16",
|
||||||
|
"mediumint": "int",
|
||||||
|
"bigint": "int64",
|
||||||
|
"unsigned big int": "uint64",
|
||||||
|
"int2": "int16",
|
||||||
|
"int8": "int64",
|
||||||
|
// Real affinity
|
||||||
|
"real": "float64",
|
||||||
|
"double": "float64",
|
||||||
|
"double precision": "float64",
|
||||||
|
"float": "float32",
|
||||||
|
// Numeric affinity
|
||||||
|
"numeric": "decimal",
|
||||||
|
"decimal": "decimal",
|
||||||
|
// Boolean (stored as integer in SQLite)
|
||||||
|
"boolean": "bool",
|
||||||
|
"bool": "bool",
|
||||||
|
// Date/time (stored as text in SQLite)
|
||||||
|
"date": "date",
|
||||||
|
"datetime": "timestamp",
|
||||||
|
"timestamp": "timestamp",
|
||||||
|
// Text affinity
|
||||||
|
"text": "string",
|
||||||
|
"varchar": "string",
|
||||||
|
"char": "string",
|
||||||
|
"character": "string",
|
||||||
|
"varying character": "string",
|
||||||
|
"nchar": "string",
|
||||||
|
"nvarchar": "string",
|
||||||
|
"clob": "text",
|
||||||
|
// Blob affinity
|
||||||
|
"blob": "bytea",
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalToSQLiteAffinity maps type names to SQLite type affinity names.
|
||||||
|
// Accepts both Go canonical names ("int", "string") and SQL canonical names
|
||||||
|
// ("integer", "varchar") so the writer handles input from any reader.
|
||||||
|
// The five SQLite type affinities are TEXT, INTEGER, REAL, NUMERIC, BLOB.
|
||||||
|
var CanonicalToSQLiteAffinity = map[string]string{
|
||||||
|
// INTEGER affinity — Go canonical
|
||||||
|
"int": "INTEGER",
|
||||||
|
"int8": "INTEGER",
|
||||||
|
"int16": "INTEGER",
|
||||||
|
"int32": "INTEGER",
|
||||||
|
"int64": "INTEGER",
|
||||||
|
"uint": "INTEGER",
|
||||||
|
"uint8": "INTEGER",
|
||||||
|
"uint16": "INTEGER",
|
||||||
|
"uint32": "INTEGER",
|
||||||
|
"uint64": "INTEGER",
|
||||||
|
"bool": "INTEGER",
|
||||||
|
// INTEGER affinity — SQL canonical
|
||||||
|
"integer": "INTEGER",
|
||||||
|
"smallint": "INTEGER",
|
||||||
|
"bigint": "INTEGER",
|
||||||
|
"serial": "INTEGER",
|
||||||
|
"smallserial": "INTEGER",
|
||||||
|
"bigserial": "INTEGER",
|
||||||
|
"boolean": "INTEGER",
|
||||||
|
"tinyint": "INTEGER",
|
||||||
|
"mediumint": "INTEGER",
|
||||||
|
// REAL affinity — Go canonical
|
||||||
|
"float32": "REAL",
|
||||||
|
"float64": "REAL",
|
||||||
|
// REAL affinity — SQL canonical
|
||||||
|
"real": "REAL",
|
||||||
|
"float": "REAL",
|
||||||
|
"double": "REAL",
|
||||||
|
"double precision": "REAL",
|
||||||
|
// NUMERIC affinity
|
||||||
|
"decimal": "NUMERIC",
|
||||||
|
"numeric": "NUMERIC",
|
||||||
|
"money": "NUMERIC",
|
||||||
|
"smallmoney": "NUMERIC",
|
||||||
|
// BLOB affinity
|
||||||
|
"bytea": "BLOB",
|
||||||
|
"blob": "BLOB",
|
||||||
|
// TEXT affinity — Go canonical
|
||||||
|
"string": "TEXT",
|
||||||
|
"text": "TEXT",
|
||||||
|
// TEXT affinity — SQL canonical
|
||||||
|
"varchar": "TEXT",
|
||||||
|
"char": "TEXT",
|
||||||
|
"nvarchar": "TEXT",
|
||||||
|
"nchar": "TEXT",
|
||||||
|
"citext": "TEXT",
|
||||||
|
"date": "TEXT",
|
||||||
|
"time": "TEXT",
|
||||||
|
"timetz": "TEXT",
|
||||||
|
"timestamp": "TEXT",
|
||||||
|
"timestamptz": "TEXT",
|
||||||
|
"datetime": "TEXT",
|
||||||
|
"uuid": "TEXT",
|
||||||
|
"json": "TEXT",
|
||||||
|
"jsonb": "TEXT",
|
||||||
|
"xml": "TEXT",
|
||||||
|
"inet": "TEXT",
|
||||||
|
"cidr": "TEXT",
|
||||||
|
"macaddr": "TEXT",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertSQLiteToCanonical converts a SQLite type name to the canonical type.
|
||||||
|
// Strips dimension parameters (e.g. VARCHAR(255) → string) and handles
|
||||||
|
// SQLite's flexible affinity rules. Defaults to "string" for unknown types.
|
||||||
|
func ConvertSQLiteToCanonical(sqliteType string) string {
|
||||||
|
base := strings.ToUpper(strings.TrimSpace(sqliteType))
|
||||||
|
if idx := strings.Index(base, "("); idx >= 0 {
|
||||||
|
base = strings.TrimSpace(base[:idx])
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(base)
|
||||||
|
|
||||||
|
if canonical, ok := SQLiteToCanonicalTypes[lower]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefix match for types like "VARYING CHARACTER(255)"
|
||||||
|
for key, canonical := range SQLiteToCanonicalTypes {
|
||||||
|
if strings.HasPrefix(lower, key) {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "string"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCanonicalToSQLite converts a canonical type (or any SQL type) to its
|
||||||
|
// SQLite type affinity. Defaults to TEXT for unrecognised types.
|
||||||
|
func ConvertCanonicalToSQLite(canonicalType string) string {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(canonicalType))
|
||||||
|
if idx := strings.Index(normalized, "("); idx >= 0 {
|
||||||
|
normalized = strings.TrimSpace(normalized[:idx])
|
||||||
|
}
|
||||||
|
normalized = strings.TrimSuffix(normalized, "[]")
|
||||||
|
|
||||||
|
if affinity, ok := CanonicalToSQLiteAffinity[normalized]; ok {
|
||||||
|
return affinity
|
||||||
|
}
|
||||||
|
|
||||||
|
return "TEXT"
|
||||||
|
}
|
||||||
@@ -144,7 +144,11 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
|
|||||||
// Write header
|
// Write header
|
||||||
fmt.Fprintf(w.writer, "-- PostgreSQL Migration Script\n")
|
fmt.Fprintf(w.writer, "-- PostgreSQL Migration Script\n")
|
||||||
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n")
|
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n")
|
||||||
fmt.Fprintf(w.writer, "-- Source: %s -> %s\n\n", current.Name, model.Name)
|
fmt.Fprintf(w.writer, "-- Source: %s -> %s\n", current.Name, model.Name)
|
||||||
|
if w.options.ContinueOnError {
|
||||||
|
fmt.Fprintf(w.writer, "\\set ON_ERROR_STOP off\n")
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w.writer, "\n")
|
||||||
|
|
||||||
// Write scripts
|
// Write scripts
|
||||||
for _, script := range scripts {
|
for _, script := range scripts {
|
||||||
@@ -160,13 +164,26 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
|
|||||||
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
||||||
scripts := make([]MigrationScript, 0)
|
scripts := make([]MigrationScript, 0)
|
||||||
|
|
||||||
// Phase 1: Drop constraints and indexes that changed (Priority 11-50)
|
if schemaRequiresPGTrgm(model) {
|
||||||
|
scripts = append(scripts, MigrationScript{
|
||||||
|
ObjectName: "extension.pg_trgm",
|
||||||
|
ObjectType: "create extension",
|
||||||
|
Schema: model.Name,
|
||||||
|
Priority: 80,
|
||||||
|
Sequence: len(scripts),
|
||||||
|
Body: "CREATE EXTENSION IF NOT EXISTS pg_trgm;",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1: Drop constraints and indexes that changed (Priority 5-50)
|
||||||
|
var droppedFKs map[string]bool
|
||||||
if current != nil {
|
if current != nil {
|
||||||
dropScripts, err := w.generateDropScripts(model, current)
|
dropScripts, dropped, err := w.generateDropScripts(model, current)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate drop scripts: %w", err)
|
return nil, fmt.Errorf("failed to generate drop scripts: %w", err)
|
||||||
}
|
}
|
||||||
scripts = append(scripts, dropScripts...)
|
scripts = append(scripts, dropScripts...)
|
||||||
|
droppedFKs = dropped
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 3: Create/Alter tables and columns (Priority 100-145)
|
// Phase 3: Create/Alter tables and columns (Priority 100-145)
|
||||||
@@ -184,7 +201,7 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m
|
|||||||
scripts = append(scripts, indexScripts...)
|
scripts = append(scripts, indexScripts...)
|
||||||
|
|
||||||
// Phase 5: Create foreign keys (Priority 195)
|
// Phase 5: Create foreign keys (Priority 195)
|
||||||
fkScripts, err := w.generateForeignKeyScripts(model, current)
|
fkScripts, err := w.generateForeignKeyScripts(model, current, droppedFKs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate foreign key scripts: %w", err)
|
return nil, fmt.Errorf("failed to generate foreign key scripts: %w", err)
|
||||||
}
|
}
|
||||||
@@ -200,9 +217,12 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m
|
|||||||
return scripts, nil
|
return scripts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateDropScripts generates DROP scripts using templates
|
// generateDropScripts generates DROP scripts using templates.
|
||||||
func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
// Returns the scripts and a set of FK constraint keys (schema.table.name) that were
|
||||||
|
// explicitly dropped because their referenced PK was being dropped, so they can be force-recreated.
|
||||||
|
func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, map[string]bool, error) {
|
||||||
scripts := make([]MigrationScript, 0)
|
scripts := make([]MigrationScript, 0)
|
||||||
|
droppedFKs := make(map[string]bool)
|
||||||
|
|
||||||
// Build map of model tables for quick lookup
|
// Build map of model tables for quick lookup
|
||||||
modelTables := make(map[string]*models.Table)
|
modelTables := make(map[string]*models.Table)
|
||||||
@@ -229,6 +249,44 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
|
|||||||
shouldDrop = true
|
shouldDrop = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shouldDrop && currentConstraint.Type == models.PrimaryKeyConstraint {
|
||||||
|
// Drop FK constraints that depend on this PK before dropping the PK itself.
|
||||||
|
for _, otherTable := range current.Tables {
|
||||||
|
for fkName, fkConstraint := range otherTable.Constraints {
|
||||||
|
if fkConstraint.Type != models.ForeignKeyConstraint {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
refTable := fkConstraint.ReferencedTable
|
||||||
|
refSchema := fkConstraint.ReferencedSchema
|
||||||
|
if refSchema == "" {
|
||||||
|
refSchema = current.Name
|
||||||
|
}
|
||||||
|
if strings.EqualFold(refTable, currentTable.Name) && strings.EqualFold(refSchema, current.Name) {
|
||||||
|
fkKey := fmt.Sprintf("%s.%s.%s", current.Name, otherTable.Name, fkName)
|
||||||
|
if !droppedFKs[fkKey] {
|
||||||
|
droppedFKs[fkKey] = true
|
||||||
|
sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{
|
||||||
|
SchemaName: current.Name,
|
||||||
|
TableName: otherTable.Name,
|
||||||
|
ConstraintName: fkName,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
scripts = append(scripts, MigrationScript{
|
||||||
|
ObjectName: fkKey,
|
||||||
|
ObjectType: "drop constraint",
|
||||||
|
Schema: current.Name,
|
||||||
|
Priority: 5,
|
||||||
|
Sequence: len(scripts),
|
||||||
|
Body: sql,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if shouldDrop {
|
if shouldDrop {
|
||||||
sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{
|
sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{
|
||||||
SchemaName: current.Name,
|
SchemaName: current.Name,
|
||||||
@@ -236,7 +294,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
|
|||||||
ConstraintName: constraintName,
|
ConstraintName: constraintName,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
script := MigrationScript{
|
script := MigrationScript{
|
||||||
@@ -268,7 +326,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
|
|||||||
IndexName: indexName,
|
IndexName: indexName,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
script := MigrationScript{
|
script := MigrationScript{
|
||||||
@@ -284,7 +342,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scripts, nil
|
return scripts, droppedFKs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateTableScripts generates CREATE/ALTER TABLE scripts using templates
|
// generateTableScripts generates CREATE/ALTER TABLE scripts using templates
|
||||||
@@ -361,7 +419,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
ColumnType: effectiveColumnSQLType(modelCol),
|
||||||
Default: defaultVal,
|
Default: defaultVal,
|
||||||
NotNull: modelCol.NotNull,
|
NotNull: modelCol.NotNull,
|
||||||
})
|
})
|
||||||
@@ -380,12 +438,13 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
scripts = append(scripts, script)
|
scripts = append(scripts, script)
|
||||||
} else if !columnsEqual(modelCol, currentCol) {
|
} else if !columnsEqual(modelCol, currentCol) {
|
||||||
// Column exists but properties changed
|
// Column exists but properties changed
|
||||||
if modelCol.Type != currentCol.Type {
|
if !columnTypesEqual(modelCol, currentCol) {
|
||||||
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
|
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
|
||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
NewType: effectiveAlterColumnSQLType(modelCol),
|
||||||
|
UsingExpr: buildAlterColumnUsingExpression(modelCol.Name, effectiveAlterColumnSQLType(modelCol)),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -604,24 +663,25 @@ func buildIndexColumnExpressions(table *models.Table, index *models.Index, index
|
|||||||
for _, colName := range index.Columns {
|
for _, colName := range index.Columns {
|
||||||
colExpr := colName
|
colExpr := colName
|
||||||
if table != nil {
|
if table != nil {
|
||||||
if col, ok := table.Columns[colName]; ok && col != nil {
|
if col, ok := resolveIndexColumn(table, colName); ok && col != nil {
|
||||||
colExpr = col.SQLName()
|
colExpr = col.SQLName()
|
||||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
if strings.EqualFold(indexType, "gin") {
|
||||||
opClass := extractOperatorClass(index.Comment)
|
opClass := ginOperatorClassForColumn(col, index.Comment)
|
||||||
if opClass == "" {
|
if opClass != "" {
|
||||||
opClass = "gin_trgm_ops"
|
|
||||||
}
|
|
||||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columnExprs = append(columnExprs, colExpr)
|
columnExprs = append(columnExprs, colExpr)
|
||||||
}
|
}
|
||||||
return columnExprs
|
return columnExprs
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
|
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates.
|
||||||
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
|
// forceRecreate is a set of FK constraint keys (schema.table.name) that must be recreated
|
||||||
|
// even if unchanged, because their referenced PK was dropped and recreated.
|
||||||
|
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema, forceRecreate map[string]bool) ([]MigrationScript, error) {
|
||||||
scripts := make([]MigrationScript, 0)
|
scripts := make([]MigrationScript, 0)
|
||||||
|
|
||||||
// Build map of current tables
|
// Build map of current tables
|
||||||
@@ -642,13 +702,16 @@ func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, curren
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldCreate := true
|
fkKey := fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName)
|
||||||
|
shouldCreate := forceRecreate[fkKey]
|
||||||
|
|
||||||
if currentTable != nil {
|
if !shouldCreate {
|
||||||
if currentConstraint, exists := currentTable.Constraints[constraintName]; exists {
|
if currentTable == nil {
|
||||||
if constraintsEqual(constraint, currentConstraint) {
|
shouldCreate = true
|
||||||
shouldCreate = false
|
} else if currentConstraint, exists := currentTable.Constraints[constraintName]; !exists {
|
||||||
}
|
shouldCreate = true
|
||||||
|
} else if !constraintsEqual(constraint, currentConstraint) {
|
||||||
|
shouldCreate = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -875,11 +938,21 @@ func columnsEqual(col1, col2 *models.Column) bool {
|
|||||||
if col1 == nil || col2 == nil {
|
if col1 == nil || col2 == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return strings.EqualFold(col1.Type, col2.Type) &&
|
return columnTypesEqual(col1, col2) &&
|
||||||
col1.NotNull == col2.NotNull &&
|
col1.NotNull == col2.NotNull &&
|
||||||
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
|
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func columnTypesEqual(col1, col2 *models.Column) bool {
|
||||||
|
if col1 == nil || col2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(
|
||||||
|
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col1)),
|
||||||
|
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col2)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// constraintsEqual checks if two constraints are equal
|
// constraintsEqual checks if two constraints are equal
|
||||||
func constraintsEqual(c1, c2 *models.Constraint) bool {
|
func constraintsEqual(c1, c2 *models.Constraint) bool {
|
||||||
if c1 == nil || c2 == nil {
|
if c1 == nil || c2 == nil {
|
||||||
|
|||||||
@@ -97,6 +97,160 @@ func TestWriteMigration_ArrayDefault(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_AltersColumnTypeWhenActualTypeDiffers(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("learnings", "public")
|
||||||
|
currentDetails := models.InitColumn("details", "learnings", "public")
|
||||||
|
currentDetails.Type = "jsonb"
|
||||||
|
currentTable.Columns["details"] = currentDetails
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("learnings", "public")
|
||||||
|
modelDetails := models.InitColumn("details", "learnings", "public")
|
||||||
|
modelDetails.Type = "text"
|
||||||
|
modelTable.Columns["details"] = modelDetails
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.learnings") || !strings.Contains(output, "ALTER COLUMN details TYPE text") {
|
||||||
|
t.Fatalf("expected migration to alter mismatched column type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN details TYPE text USING details::text;`) {
|
||||||
|
t.Fatalf("expected migration type alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("learnings", "public")
|
||||||
|
currentID := models.InitColumn("id", "learnings", "public")
|
||||||
|
currentID.Type = "uuid"
|
||||||
|
currentTable.Columns["id"] = currentID
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("learnings", "public")
|
||||||
|
modelID := models.InitColumn("id", "learnings", "public")
|
||||||
|
modelID.Type = "bigserial"
|
||||||
|
modelTable.Columns["id"] = modelID
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
|
||||||
|
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
|
||||||
|
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
|
||||||
|
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_ArrayAlterIncludesUsingCast(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("learnings", "public")
|
||||||
|
currentTags := models.InitColumn("tags", "learnings", "public")
|
||||||
|
currentTags.Type = "text"
|
||||||
|
currentTable.Columns["tags"] = currentTags
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("learnings", "public")
|
||||||
|
modelTags := models.InitColumn("tags", "learnings", "public")
|
||||||
|
modelTags.Type = "text[]"
|
||||||
|
modelTable.Columns["tags"] = modelTags
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
|
||||||
|
t.Fatalf("expected array alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_DoesNotAlterEquivalentNormalizedColumnType(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
currentTable := models.InitTable("users", "public")
|
||||||
|
currentEmail := models.InitColumn("email", "users", "public")
|
||||||
|
currentEmail.Type = "character varying"
|
||||||
|
currentEmail.Length = 255
|
||||||
|
currentTable.Columns["email"] = currentEmail
|
||||||
|
currentSchema.Tables = append(currentSchema.Tables, currentTable)
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
modelTable := models.InitTable("users", "public")
|
||||||
|
modelEmail := models.InitColumn("email", "users", "public")
|
||||||
|
modelEmail.Type = "varchar(255)"
|
||||||
|
modelTable.Columns["email"] = modelEmail
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, modelTable)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if strings.Contains(output, "ALTER COLUMN email TYPE") {
|
||||||
|
t.Fatalf("did not expect alter type for equivalent normalized types, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
||||||
current := models.InitDatabase("testdb")
|
current := models.InitDatabase("testdb")
|
||||||
currentSchema := models.InitSchema("public")
|
currentSchema := models.InitSchema("public")
|
||||||
@@ -132,11 +286,54 @@ func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
output := buf.String()
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") {
|
||||||
|
t.Fatalf("expected trigram extension for text GIN migration index, got:\n%s", output)
|
||||||
|
}
|
||||||
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
|
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
|
||||||
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
|
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("agent_personas", "public")
|
||||||
|
nameCol := models.InitColumn("name", "agent_personas", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_agent_personas_name_gin",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{`"name"`},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (name gin_trgm_ops)") {
|
||||||
|
t.Fatalf("expected quoted text column GIN index to include gin_trgm_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
|
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
|
||||||
current := models.InitDatabase("testdb")
|
current := models.InitDatabase("testdb")
|
||||||
currentSchema := models.InitSchema("public")
|
currentSchema := models.InitSchema("public")
|
||||||
@@ -172,14 +369,98 @@ func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *tes
|
|||||||
}
|
}
|
||||||
|
|
||||||
output := buf.String()
|
output := buf.String()
|
||||||
if !strings.Contains(output, "USING gin (tags)") {
|
if !strings.Contains(output, "USING gin (tags array_ops)") {
|
||||||
t.Fatalf("expected GIN array index without explicit trigram opclass, got:\n%s", output)
|
t.Fatalf("expected GIN array index with array_ops, got:\n%s", output)
|
||||||
}
|
}
|
||||||
if strings.Contains(output, "gin_trgm_ops") {
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
|
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||||
|
detailsCol.Type = "jsonb"
|
||||||
|
table.Columns["details"] = detailsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_learnings_details",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"details"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
|
||||||
|
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
|
t.Fatalf("did not expect gin_trgm_ops for jsonb migration index, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_GinIndexOnJSONBIgnoresIncompatibleTrigramOperatorClass(t *testing.T) {
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("public")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||||
|
detailsCol.Type = "jsonb"
|
||||||
|
table.Columns["details"] = detailsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_learnings_details",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"details"},
|
||||||
|
Comment: "gin_trgm_ops",
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteMigration(model, current); err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
|
||||||
|
t.Fatalf("expected incompatible trigram hint on jsonb to fall back to jsonb_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteMigration_WithAudit(t *testing.T) {
|
func TestWriteMigration_WithAudit(t *testing.T) {
|
||||||
// Current database (empty)
|
// Current database (empty)
|
||||||
current := models.InitDatabase("testdb")
|
current := models.InitDatabase("testdb")
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TemplateFunctions returns a map of custom template functions
|
// TemplateFunctions returns a map of custom template functions
|
||||||
@@ -162,24 +164,12 @@ func quoteIdent(s string) string {
|
|||||||
|
|
||||||
// Type conversion functions
|
// Type conversion functions
|
||||||
|
|
||||||
// goTypeToSQL converts Go type to PostgreSQL type
|
// goTypeToSQL converts Go type to PostgreSQL type using the shared pgsql type map.
|
||||||
func goTypeToSQL(goType string) string {
|
func goTypeToSQL(goType string) string {
|
||||||
typeMap := map[string]string{
|
if sqlType, ok := pgsql.GoToPGSQLTypes[goType]; ok {
|
||||||
"string": "text",
|
|
||||||
"int": "integer",
|
|
||||||
"int32": "integer",
|
|
||||||
"int64": "bigint",
|
|
||||||
"float32": "real",
|
|
||||||
"float64": "double precision",
|
|
||||||
"bool": "boolean",
|
|
||||||
"time.Time": "timestamp",
|
|
||||||
"[]byte": "bytea",
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqlType, ok := typeMap[goType]; ok {
|
|
||||||
return sqlType
|
return sqlType
|
||||||
}
|
}
|
||||||
return "text" // Default
|
return "text"
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlTypeToGo converts PostgreSQL type to Go type
|
// sqlTypeToGo converts PostgreSQL type to Go type
|
||||||
|
|||||||
@@ -95,6 +95,16 @@ type AlterColumnTypeData struct {
|
|||||||
TableName string
|
TableName string
|
||||||
ColumnName string
|
ColumnName string
|
||||||
NewType string
|
NewType string
|
||||||
|
UsingExpr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AlterColumnTypeWithCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
NewType string
|
||||||
|
EquivalentTypes string
|
||||||
|
UsingExpr string
|
||||||
}
|
}
|
||||||
|
|
||||||
// AlterColumnDefaultData contains data for alter column default template
|
// AlterColumnDefaultData contains data for alter column default template
|
||||||
@@ -302,6 +312,15 @@ func (te *TemplateExecutor) ExecuteAlterColumnType(data AlterColumnTypeData) (st
|
|||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (te *TemplateExecutor) ExecuteAlterColumnTypeWithCheck(data AlterColumnTypeWithCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "alter_column_type_with_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute alter_column_type_with_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// ExecuteAlterColumnDefault executes the alter column default template
|
// ExecuteAlterColumnDefault executes the alter column default template
|
||||||
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
|
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
current_type text;
|
||||||
|
BEGIN
|
||||||
|
SELECT pg_catalog.format_type(a.atttypid, a.atttypmod)
|
||||||
|
INTO current_type
|
||||||
|
FROM pg_attribute a
|
||||||
|
JOIN pg_class t ON t.oid = a.attrelid
|
||||||
|
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||||
|
WHERE n.nspname = '{{.SchemaName}}'
|
||||||
|
AND t.relname = '{{.TableName}}'
|
||||||
|
AND a.attname = '{{.ColumnName}}'
|
||||||
|
AND a.attnum > 0
|
||||||
|
AND NOT a.attisdropped;
|
||||||
|
|
||||||
|
IF current_type IS NOT NULL
|
||||||
|
AND current_type <> ALL(ARRAY[{{.EquivalentTypes}}]) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
|
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -6,7 +6,7 @@ BEGIN
|
|||||||
SELECT tc.constraint_name,
|
SELECT tc.constraint_name,
|
||||||
COALESCE(
|
COALESCE(
|
||||||
ARRAY(
|
ARRAY(
|
||||||
SELECT a.attname
|
SELECT a.attname::text
|
||||||
FROM pg_constraint c
|
FROM pg_constraint c
|
||||||
JOIN pg_class t ON t.oid = c.conrelid
|
JOIN pg_class t ON t.oid = c.conrelid
|
||||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||||
@@ -31,7 +31,7 @@ BEGIN
|
|||||||
IF current_pk_name IS NOT NULL
|
IF current_pk_name IS NOT NULL
|
||||||
AND NOT current_pk_matches
|
AND NOT current_pk_matches
|
||||||
AND current_pk_name IN ({{.AutoGenNames}}) THEN
|
AND current_pk_name IN ({{.AutoGenNames}}) THEN
|
||||||
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name);
|
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name) || ' CASCADE';
|
||||||
END IF;
|
END IF;
|
||||||
|
|
||||||
-- Add the desired primary key only when no matching primary key already exists.
|
-- Add the desired primary key only when no matching primary key already exists.
|
||||||
|
|||||||
+255
-49
@@ -99,7 +99,11 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
|
|||||||
// Write header comment
|
// Write header comment
|
||||||
fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n")
|
fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n")
|
||||||
fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
|
fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
|
||||||
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n")
|
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n")
|
||||||
|
if w.options.ContinueOnError {
|
||||||
|
fmt.Fprintf(w.writer, "\\set ON_ERROR_STOP off\n")
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w.writer, "\n")
|
||||||
|
|
||||||
// Process each schema in the database
|
// Process each schema in the database
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
@@ -143,6 +147,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if schemaRequiresPGTrgm(schema) {
|
||||||
|
statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`)
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 2: Create sequences
|
// Phase 2: Create sequences
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
pk := table.GetPrimaryKey()
|
pk := table.GetPrimaryKey()
|
||||||
@@ -181,6 +189,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
}
|
}
|
||||||
statements = append(statements, addColStmts...)
|
statements = append(statements, addColStmts...)
|
||||||
|
|
||||||
|
alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate alter column type statements: %w", err)
|
||||||
|
}
|
||||||
|
statements = append(statements, alterTypeStmts...)
|
||||||
|
|
||||||
// Phase 4: Primary keys
|
// Phase 4: Primary keys
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
// First check for explicit PrimaryKeyConstraint
|
// First check for explicit PrimaryKeyConstraint
|
||||||
@@ -261,16 +275,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
columnExprs := make([]string, 0, len(index.Columns))
|
columnExprs := make([]string, 0, len(index.Columns))
|
||||||
for _, colName := range index.Columns {
|
for _, colName := range index.Columns {
|
||||||
colExpr := colName
|
colExpr := colName
|
||||||
if col, ok := table.Columns[colName]; ok {
|
if col, ok := resolveIndexColumn(table, colName); ok {
|
||||||
// For GIN indexes on text columns, add operator class
|
if strings.EqualFold(indexType, "gin") {
|
||||||
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" {
|
||||||
opClass := extractOperatorClass(index.Comment)
|
|
||||||
if opClass == "" {
|
|
||||||
opClass = "gin_trgm_ops"
|
|
||||||
}
|
|
||||||
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columnExprs = append(columnExprs, colExpr)
|
columnExprs = append(columnExprs, colExpr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -437,6 +448,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
|
|||||||
return statements, nil
|
return statements, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) {
|
||||||
|
statements := []string{}
|
||||||
|
|
||||||
|
statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name))
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
columns := getSortedColumns(table.Columns)
|
||||||
|
for _, col := range columns {
|
||||||
|
targetType := effectiveAlterColumnSQLType(col)
|
||||||
|
stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ColumnName: col.Name,
|
||||||
|
NewType: targetType,
|
||||||
|
EquivalentTypes: equivalentTypeListSQL(targetType),
|
||||||
|
UsingExpr: buildAlterColumnUsingExpression(col.Name, targetType),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
||||||
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
||||||
statements := []string{}
|
statements := []string{}
|
||||||
@@ -489,31 +527,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
|||||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||||
parts := []string{col.SQLName()}
|
parts := []string{col.SQLName()}
|
||||||
|
|
||||||
// Type with length/precision - convert to valid PostgreSQL type
|
parts = append(parts, effectiveColumnSQLType(col))
|
||||||
baseType := pgsql.ConvertSQLType(col.Type)
|
|
||||||
typeStr := baseType
|
|
||||||
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
|
||||||
|
|
||||||
// Only add size specifiers for types that support them
|
|
||||||
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
|
||||||
if pgsql.SupportsLength(baseType) {
|
|
||||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
|
||||||
} else if isTextTypeWithoutLength(baseType) {
|
|
||||||
// Convert text with length to varchar
|
|
||||||
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
|
||||||
}
|
|
||||||
// For types that don't support length (integer, bigint, etc.), ignore the length
|
|
||||||
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
|
||||||
if pgsql.SupportsPrecision(baseType) {
|
|
||||||
if col.Scale > 0 {
|
|
||||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
|
||||||
} else {
|
|
||||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// For types that don't support precision, ignore it
|
|
||||||
}
|
|
||||||
parts = append(parts, typeStr)
|
|
||||||
|
|
||||||
// NOT NULL
|
// NOT NULL
|
||||||
if col.NotNull {
|
if col.NotNull {
|
||||||
@@ -535,6 +549,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|||||||
return strings.Join(parts, " ")
|
return strings.Join(parts, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func effectiveColumnSQLType(col *models.Column) string {
|
||||||
|
if col == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
baseType := pgsql.ConvertSQLType(col.Type)
|
||||||
|
typeStr := baseType
|
||||||
|
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
|
||||||
|
|
||||||
|
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
|
||||||
|
if pgsql.SupportsLength(baseType) {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||||
|
} else if isTextTypeWithoutLength(baseType) {
|
||||||
|
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||||
|
}
|
||||||
|
} else if !hasExplicitTypeModifier && col.Precision > 0 {
|
||||||
|
if pgsql.SupportsPrecision(baseType) {
|
||||||
|
if col.Scale > 0 {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||||
|
} else {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typeStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveAlterColumnSQLType(col *models.Column) string {
|
||||||
|
typeStr := effectiveColumnSQLType(col)
|
||||||
|
switch strings.ToLower(strings.TrimSpace(typeStr)) {
|
||||||
|
case "smallserial":
|
||||||
|
return "smallint"
|
||||||
|
case "serial":
|
||||||
|
return "integer"
|
||||||
|
case "bigserial":
|
||||||
|
return "bigint"
|
||||||
|
default:
|
||||||
|
return typeStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAlterColumnUsingExpression(columnName, targetType string) string {
|
||||||
|
if strings.TrimSpace(columnName) == "" || strings.TrimSpace(targetType) == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func equivalentTypeListSQL(sqlType string) string {
|
||||||
|
variants := pgsql.EquivalentSQLTypeVariants(sqlType)
|
||||||
|
quoted := make([]string, 0, len(variants))
|
||||||
|
for _, variant := range variants {
|
||||||
|
quoted = append(quoted, fmt.Sprintf("'%s'", escapeQuote(variant)))
|
||||||
|
}
|
||||||
|
return strings.Join(quoted, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
// WriteSchema writes a single schema and all its tables
|
// WriteSchema writes a single schema and all its tables
|
||||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||||
if w.writer == nil {
|
if w.writer == nil {
|
||||||
@@ -546,6 +618,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := w.writeRequiredExtensions(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 2: Create sequences (priority 80)
|
// Phase 2: Create sequences (priority 80)
|
||||||
if err := w.writeSequences(schema); err != nil {
|
if err := w.writeSequences(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -561,6 +637,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := w.writeAlterColumnTypes(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 4: Create primary keys (priority 160)
|
// Phase 4: Create primary keys (priority 160)
|
||||||
if err := w.writePrimaryKeys(schema); err != nil {
|
if err := w.writePrimaryKeys(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -661,6 +741,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) writeRequiredExtensions(schema *models.Schema) error {
|
||||||
|
if !schemaRequiresPGTrgm(schema) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;")
|
||||||
|
fmt.Fprintln(w.writer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
||||||
func (w *Writer) writeSequences(schema *models.Schema) error {
|
func (w *Writer) writeSequences(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
||||||
@@ -754,6 +844,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error {
|
||||||
|
fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name)
|
||||||
|
|
||||||
|
statements, err := w.GenerateAlterColumnTypeStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, stmt := range statements[1:] {
|
||||||
|
fmt.Fprint(w.writer, stmt)
|
||||||
|
fmt.Fprint(w.writer, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
||||||
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
||||||
@@ -855,16 +960,14 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
|||||||
// Build column list with operator class support for GIN indexes
|
// Build column list with operator class support for GIN indexes
|
||||||
columnExprs := make([]string, 0, len(index.Columns))
|
columnExprs := make([]string, 0, len(index.Columns))
|
||||||
for _, colName := range index.Columns {
|
for _, colName := range index.Columns {
|
||||||
if col, ok := table.Columns[colName]; ok {
|
if col, ok := resolveIndexColumn(table, colName); ok {
|
||||||
colExpr := col.SQLName()
|
colExpr := col.SQLName()
|
||||||
// For GIN indexes on text columns, add operator class
|
if strings.EqualFold(index.Type, "gin") {
|
||||||
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
|
opClass := ginOperatorClassForColumn(col, index.Comment)
|
||||||
opClass := extractOperatorClass(index.Comment)
|
if opClass != "" {
|
||||||
if opClass == "" {
|
|
||||||
opClass = "gin_trgm_ops"
|
|
||||||
}
|
|
||||||
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
columnExprs = append(columnExprs, colExpr)
|
columnExprs = append(columnExprs, colExpr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1250,23 +1353,126 @@ func isIntegerType(colType string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isTextType checks if a column type is a text type (for GIN index operator class)
|
// isTextType checks if a column type is a text type (for GIN index operator class)
|
||||||
func isTextType(colType string) bool {
|
// func isTextType(colType string) bool {
|
||||||
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
// textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
|
||||||
lowerType := strings.ToLower(colType)
|
// lowerType := strings.ToLower(colType)
|
||||||
if strings.HasSuffix(lowerType, "[]") {
|
// if strings.HasSuffix(lowerType, "[]") {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// for _, t := range textTypes {
|
||||||
|
// if strings.HasPrefix(lowerType, t) {
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
||||||
|
func isTextTypeWithoutLength(colType string) bool {
|
||||||
|
return strings.EqualFold(colType, "text")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ginOperatorClassForColumn(col *models.Column, comment string) string {
|
||||||
|
if col == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlType := effectiveColumnSQLType(col)
|
||||||
|
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
|
||||||
|
isArray := pgsql.IsArrayType(sqlType)
|
||||||
|
requested := extractOperatorClass(comment)
|
||||||
|
|
||||||
|
if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) {
|
||||||
|
return requested
|
||||||
|
}
|
||||||
|
|
||||||
|
if isArray {
|
||||||
|
return "array_ops"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case isTextGinBaseType(baseType):
|
||||||
|
return "gin_trgm_ops"
|
||||||
|
case baseType == "jsonb":
|
||||||
|
return "jsonb_ops"
|
||||||
|
default:
|
||||||
|
return requested
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool {
|
||||||
|
switch opClass {
|
||||||
|
case "gin_trgm_ops", "gin_bigm_ops":
|
||||||
|
return !isArray && isTextGinBaseType(baseType)
|
||||||
|
case "jsonb_ops", "jsonb_path_ops":
|
||||||
|
return !isArray && baseType == "jsonb"
|
||||||
|
case "array_ops":
|
||||||
|
return isArray
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTextGinBaseType(baseType string) bool {
|
||||||
|
switch baseType {
|
||||||
|
case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, t := range textTypes {
|
}
|
||||||
if strings.HasPrefix(lowerType, t) {
|
|
||||||
|
func schemaRequiresPGTrgm(schema *models.Schema) bool {
|
||||||
|
if schema == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
if table == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, index := range table.Indexes {
|
||||||
|
if index == nil || !strings.EqualFold(index.Type, "gin") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, colName := range index.Columns {
|
||||||
|
col, ok := resolveIndexColumn(table, colName)
|
||||||
|
if !ok || col == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
|
||||||
func isTextTypeWithoutLength(colType string) bool {
|
if table == nil {
|
||||||
return strings.EqualFold(colType, "text")
|
return nil, false
|
||||||
|
}
|
||||||
|
if col, ok := table.Columns[colName]; ok && col != nil {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := strings.ToLower(strings.Trim(colName, `"`))
|
||||||
|
for key, col := range table.Columns {
|
||||||
|
if col == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.Trim(key, `"`)) == normalized {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.Trim(col.Name, `"`)) == normalized {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.Trim(col.SQLName(), `"`)) == normalized {
|
||||||
|
return col, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
|
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
|
||||||
|
|||||||
@@ -116,14 +116,88 @@ func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *test
|
|||||||
}
|
}
|
||||||
|
|
||||||
output := buf.String()
|
output := buf.String()
|
||||||
if !strings.Contains(output, `USING gin (tags)`) {
|
if !strings.Contains(output, `USING gin (tags array_ops)`) {
|
||||||
t.Fatalf("expected GIN index on array column without explicit trigram opclass, got:\n%s", output)
|
t.Fatalf("expected GIN index on array column with array_ops, got:\n%s", output)
|
||||||
}
|
}
|
||||||
if strings.Contains(output, "gin_trgm_ops") {
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
|
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("agent_personas", "public")
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "agent_personas", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_agent_personas_name_gin",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{`"name"`},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) {
|
||||||
|
t.Fatalf("expected trigram extension for text GIN index, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) {
|
||||||
|
t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
|
||||||
|
detailsCol := models.InitColumn("details", "learnings", "public")
|
||||||
|
detailsCol.Type = "jsonb"
|
||||||
|
table.Columns["details"] = detailsCol
|
||||||
|
|
||||||
|
index := &models.Index{
|
||||||
|
Name: "idx_learnings_details",
|
||||||
|
Type: "gin",
|
||||||
|
Columns: []string{"details"},
|
||||||
|
}
|
||||||
|
table.Indexes[index.Name] = index
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, `USING gin (details jsonb_ops)`) {
|
||||||
|
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "gin_trgm_ops") {
|
||||||
|
t.Fatalf("did not expect gin_trgm_ops for jsonb column, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteForeignKeys(t *testing.T) {
|
func TestWriteForeignKeys(t *testing.T) {
|
||||||
// Create a test database with two related tables
|
// Create a test database with two related tables
|
||||||
db := models.InitDatabase("testdb")
|
db := models.InitDatabase("testdb")
|
||||||
@@ -984,3 +1058,82 @@ func TestWriteAddColumnStatements(t *testing.T) {
|
|||||||
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteSchema_EmitsGuardedAlterColumnTypeStatements(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("agent_skills", "public")
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "agent_skills", "public")
|
||||||
|
nameCol.Type = "character varying"
|
||||||
|
nameCol.Length = 255
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
tagsCol := models.InitColumn("tags", "agent_skills", "public")
|
||||||
|
tagsCol.Type = "text[]"
|
||||||
|
table.Columns["tags"] = tagsCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "-- Alter column types for schema: public") {
|
||||||
|
t.Fatalf("expected alter column type section, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "pg_catalog.format_type") {
|
||||||
|
t.Fatalf("expected guarded live-type check, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN name TYPE character varying(255)") {
|
||||||
|
t.Fatalf("expected guarded alter for character varying(255), got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ARRAY['varchar(255)', 'character varying(255)']") {
|
||||||
|
t.Fatalf("expected equivalent type spellings for varchar guard, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN tags TYPE text[]") {
|
||||||
|
t.Fatalf("expected guarded alter for array type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
|
||||||
|
t.Fatalf("expected guarded alter for array type to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteSchema_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
table := models.InitTable("learnings", "public")
|
||||||
|
idCol := models.InitColumn("id", "learnings", "public")
|
||||||
|
idCol.Type = "bigserial"
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := NewWriter(&writers.WriterOptions{})
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
if err := writer.WriteDatabase(db); err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
|
||||||
|
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
|
||||||
|
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
|
||||||
|
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ package sqlite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
sqlitepkg "git.warky.dev/wdevs/relspecgo/pkg/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SQLite type affinities
|
// SQLite type affinity constants
|
||||||
const (
|
const (
|
||||||
TypeText = "TEXT"
|
TypeText = "TEXT"
|
||||||
TypeInteger = "INTEGER"
|
TypeInteger = "INTEGER"
|
||||||
@@ -13,72 +15,27 @@ const (
|
|||||||
TypeBlob = "BLOB"
|
TypeBlob = "BLOB"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MapPostgreSQLType maps PostgreSQL data types to SQLite type affinities
|
// MapTypeToSQLite maps any SQL or Go canonical type to a SQLite type affinity.
|
||||||
|
// Handles input from any reader (PostgreSQL, MSSQL, SQLite, Go canonical).
|
||||||
|
func MapTypeToSQLite(colType string) string {
|
||||||
|
return sqlitepkg.ConvertCanonicalToSQLite(colType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapPostgreSQLType is an alias for MapTypeToSQLite kept for compatibility.
|
||||||
|
//
|
||||||
|
// Deprecated: use MapTypeToSQLite.
|
||||||
func MapPostgreSQLType(pgType string) string {
|
func MapPostgreSQLType(pgType string) string {
|
||||||
// Normalize the type
|
return MapTypeToSQLite(pgType)
|
||||||
normalized := strings.ToLower(strings.TrimSpace(pgType))
|
|
||||||
|
|
||||||
// Remove array notation if present
|
|
||||||
normalized = strings.TrimSuffix(normalized, "[]")
|
|
||||||
|
|
||||||
// Remove precision/scale if present
|
|
||||||
if idx := strings.Index(normalized, "("); idx != -1 {
|
|
||||||
normalized = normalized[:idx]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map to SQLite type affinity
|
// IsIntegerType reports whether a column type maps to SQLite INTEGER affinity.
|
||||||
switch normalized {
|
|
||||||
// TEXT affinity
|
|
||||||
case "varchar", "character varying", "text", "char", "character",
|
|
||||||
"citext", "uuid", "timestamp", "timestamptz", "timestamp with time zone",
|
|
||||||
"timestamp without time zone", "date", "time", "timetz", "time with time zone",
|
|
||||||
"time without time zone", "json", "jsonb", "xml", "inet", "cidr", "macaddr":
|
|
||||||
return TypeText
|
|
||||||
|
|
||||||
// INTEGER affinity
|
|
||||||
case "int", "int2", "int4", "int8", "integer", "smallint", "bigint",
|
|
||||||
"serial", "smallserial", "bigserial", "boolean", "bool":
|
|
||||||
return TypeInteger
|
|
||||||
|
|
||||||
// REAL affinity
|
|
||||||
case "real", "float", "float4", "float8", "double precision":
|
|
||||||
return TypeReal
|
|
||||||
|
|
||||||
// NUMERIC affinity
|
|
||||||
case "numeric", "decimal", "money":
|
|
||||||
return TypeNumeric
|
|
||||||
|
|
||||||
// BLOB affinity
|
|
||||||
case "bytea", "blob":
|
|
||||||
return TypeBlob
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Default to TEXT for unknown types
|
|
||||||
return TypeText
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsIntegerType checks if a column type should be treated as integer
|
|
||||||
func IsIntegerType(colType string) bool {
|
func IsIntegerType(colType string) bool {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(colType))
|
return MapTypeToSQLite(colType) == TypeInteger
|
||||||
normalized = strings.TrimSuffix(normalized, "[]")
|
|
||||||
if idx := strings.Index(normalized, "("); idx != -1 {
|
|
||||||
normalized = normalized[:idx]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch normalized {
|
// MapBooleanValue converts common boolean literals to SQLite integers (1/0).
|
||||||
case "int", "int2", "int4", "int8", "integer", "smallint", "bigint",
|
|
||||||
"serial", "smallserial", "bigserial":
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MapBooleanValue converts PostgreSQL boolean literals to SQLite (0/1)
|
|
||||||
func MapBooleanValue(value string) string {
|
func MapBooleanValue(value string) string {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||||
switch normalized {
|
|
||||||
case "true", "t", "yes", "y", "1":
|
case "true", "t", "yes", "y", "1":
|
||||||
return "1"
|
return "1"
|
||||||
case "false", "f", "no", "n", "0":
|
case "false", "f", "no", "n", "0":
|
||||||
|
|||||||
@@ -54,6 +54,10 @@ type WriterOptions struct {
|
|||||||
// Prisma7 enables Prisma 7-specific output for Prisma writers.
|
// Prisma7 enables Prisma 7-specific output for Prisma writers.
|
||||||
Prisma7 bool
|
Prisma7 bool
|
||||||
|
|
||||||
|
// ContinueOnError instructs SQL writers to prepend `\set ON_ERROR_STOP off`
|
||||||
|
// to their output so that psql continues past errors instead of stopping.
|
||||||
|
ContinueOnError bool
|
||||||
|
|
||||||
// Additional options can be added here as needed
|
// Additional options can be added here as needed
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user