Compare commits

..

20 Commits

Author SHA1 Message Date
warkanum bb7ceb37fe chore(release): update package version to 1.0.58
Release / test (push) Successful in -32m53s
Release / release (push) Successful in -20m53s
Release / pkg-deb (push) Successful in -31m34s
Release / pkg-rpm (push) Successful in -31m3s
Release / pkg-aur (push) Successful in -11m7s
2026-05-19 19:26:45 +02:00
warkanum 6a759ef3d1 fix(mssql): correct order of MSSQL type mappings 2026-05-19 19:26:30 +02:00
warkanum cb735f0754 feat(sqlite): add SQLite type mapping and conversion functions
* Implement SQLiteToCanonicalTypes for type mapping
* Add ConvertSQLiteToCanonical and ConvertCanonicalToSQLite functions
* Update mapDataType to utilize new conversion logic
2026-05-19 19:26:09 +02:00
warkanum 80fb49bc5e refactor(datatypes): remove redundant normalization function 2026-05-19 19:12:54 +02:00
warkanum 9190df81dd feat(merge): enhance type conflict detection for columns
* Introduced extractTypeParts function to handle embedded dimensions in type strings.
* Updated columnTypeConflict to utilize new type extraction logic.
* Improved PostgreSQL type normalization and handling in various components.
2026-05-19 19:12:27 +02:00
Hein 9235ef5e08 chore(release): update package version to 1.0.57
Release / test (push) Successful in -32m7s
Release / release (push) Successful in -27m58s
Release / pkg-deb (push) Successful in -30m52s
Release / pkg-aur (push) Successful in -29m5s
Release / pkg-rpm (push) Successful in -29m58s
2026-05-07 14:45:02 +02:00
Hein b91d6b33b5 feat(writer): add continue-on-error option for SQL writers
* Introduce ContinueOnError option to WriterOptions
* Update writer functions to support continue-on-error behavior
* Modify migration and database writing to handle continue-on-error
2026-05-07 14:44:36 +02:00
warkanum 30ef1db010 chore(release): update package version to 1.0.56
Release / test (push) Successful in -32m19s
Release / release (push) Successful in -31m39s
Release / pkg-deb (push) Successful in -31m57s
Release / pkg-aur (push) Successful in -31m46s
Release / pkg-rpm (push) Successful in -4m7s
2026-05-05 14:51:10 +02:00
warkanum 2d97a47ee1 feat: Enhance PostgreSQL type handling and migration scripts
- Introduced equivalent base types and variants for PostgreSQL types to normalize type comparisons.
- Added functions for normalizing SQL types and retrieving equivalent type variants.
- Updated migration writer to handle type alterations with checks for existing types.
- Implemented logic to create necessary extensions (e.g., pg_trgm) based on schema requirements.
- Enhanced tests to cover new functionality for type normalization and migration handling.
- Improved handling of GIN indexes to use appropriate operator classes based on column types.
2026-05-05 14:50:34 +02:00
warkanum 72200ea72e chore(release): update package version to 1.0.55
Release / test (push) Successful in -32m1s
Release / release (push) Successful in -31m13s
Release / pkg-aur (push) Successful in -32m13s
Release / pkg-deb (push) Successful in -31m12s
Release / pkg-rpm (push) Successful in -29m45s
2026-05-05 11:36:29 +02:00
warkanum 608893a3d6 feat(index): implement GIN index support for quoted text columns and enhance index column resolution 2026-05-05 11:32:15 +02:00
warkanum 53ff745d5d chore(release): update package version to 1.0.54
Release / test (push) Successful in -31m47s
Release / release (push) Successful in -31m9s
Release / pkg-aur (push) Successful in -31m57s
Release / pkg-deb (push) Successful in -31m1s
Release / pkg-rpm (push) Successful in -29m27s
2026-05-05 11:12:49 +02:00
warkanum 17bc8ed395 feat(migration): enhance primary key handling and add GIN index support in migration writer 2026-05-05 11:12:23 +02:00
warkanum a447b68b22 chore(release): update package version to 1.0.53
Release / test (push) Successful in -31m55s
Release / release (push) Successful in -31m19s
Release / pkg-aur (push) Successful in -32m3s
Release / pkg-deb (push) Successful in -31m21s
Release / pkg-rpm (push) Successful in -28m4s
2026-05-05 10:48:27 +02:00
warkanum 4303dcf59b Support typed primary key helpers in gorm and bun writers 2026-05-05 10:32:33 +02:00
warkanum e828d48798 chore(release): update package version to 1.0.52
Release / test (push) Successful in -32m39s
Release / release (push) Successful in -32m1s
Release / pkg-deb (push) Successful in -32m9s
Release / pkg-aur (push) Successful in -31m37s
Release / pkg-rpm (push) Successful in -27m28s
2026-05-03 17:19:22 +02:00
warkanum 6e470a9239 fix(type_mapper): adjust array tag handling in BuildBunTag 2026-05-03 17:18:58 +02:00
warkanum 096815fe49 chore(release): update package version to 1.0.51
Release / test (push) Successful in -32m30s
Release / release (push) Successful in -31m54s
Release / pkg-aur (push) Successful in -32m31s
Release / pkg-deb (push) Successful in -32m7s
Release / pkg-rpm (push) Successful in -30m36s
2026-05-03 16:11:13 +02:00
warkanum b8f60203cb fix(type_mapper): handle PostgreSQL array types in tags
* Update BuildBunTag to append "array" for array types
* Add tests for handling array types in TypeMapper
* Adjust regex in SanitizeStructTagValue to preserve array suffix
2026-05-03 16:11:01 +02:00
Hein 15763f60cc Fix GIN opclass handling for array columns 2026-04-30 20:35:06 +02:00
41 changed files with 2508 additions and 423 deletions
+5 -3
View File
@@ -53,6 +53,7 @@ var (
convertSchemaFilter string
convertFlattenSchema bool
convertNullableTypes string
convertContinueOnError bool
)
var convertCmd = &cobra.Command{
@@ -177,6 +178,7 @@ func init() {
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
convertCmd.Flags().StringVar(&convertNullableTypes, "types", "", "Nullable type package for code-gen writers (bun/gorm): 'resolvespec' (default) or 'stdlib' (database/sql)")
convertCmd.Flags().BoolVar(&convertContinueOnError, "continue-on-error", false, "Prepend \\set ON_ERROR_STOP off to generated SQL so psql continues past errors (pgsql output only)")
err := convertCmd.MarkFlagRequired("from")
if err != nil {
@@ -243,7 +245,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
}
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes); err != nil {
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema, convertNullableTypes, convertContinueOnError); err != nil {
return fmt.Errorf("failed to write target: %w", err)
}
@@ -383,10 +385,10 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
return db, nil
}
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string) error {
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool, nullableTypes string, continueOnError bool) error {
var writer writers.Writer
writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes)
writerOpts := newWriterOptions(outputPath, packageName, flattenSchema, nullableTypes, continueOnError)
switch strings.ToLower(dbType) {
case "dbml":
+13 -13
View File
@@ -323,31 +323,31 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa
switch strings.ToLower(dbType) {
case "dbml":
writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wdbml.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "dctx":
writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wdctx.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "drawdb":
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "graphql":
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "json":
writer = wjson.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wjson.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "yaml":
writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wyaml.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "gorm":
writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wgorm.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "bun":
writer = wbun.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wbun.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "drizzle":
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "prisma":
writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wprisma.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "typeorm":
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "sqlite", "sqlite3":
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", false, "", false))
case "pgsql":
writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, ""))
writer = wpgsql.NewWriter(newWriterOptions(filePath, "", false, "", false))
default:
return fmt.Errorf("%s: unsupported format: %s", label, dbType)
}
+18 -13
View File
@@ -258,6 +258,11 @@ func runMerge(cmd *cobra.Command, args []string) error {
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
if strings.EqualFold(mergeOutputType, "pgsql") && len(result.TypeConflicts) > 0 {
return fmt.Errorf("merge detected conflicting existing column types and cannot safely continue with pgsql output\n%s",
merge.GetColumnTypeConflictSummary(result, 10))
}
// Step 4: Write output
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
@@ -370,61 +375,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
if filePath == "" {
return fmt.Errorf("%s: file path is required for DBML format", label)
}
writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wdbml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "dctx":
if filePath == "" {
return fmt.Errorf("%s: file path is required for DCTX format", label)
}
writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wdctx.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "drawdb":
if filePath == "" {
return fmt.Errorf("%s: file path is required for DrawDB format", label)
}
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wdrawdb.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "graphql":
if filePath == "" {
return fmt.Errorf("%s: file path is required for GraphQL format", label)
}
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wgraphql.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "json":
if filePath == "" {
return fmt.Errorf("%s: file path is required for JSON format", label)
}
writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wjson.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "yaml":
if filePath == "" {
return fmt.Errorf("%s: file path is required for YAML format", label)
}
writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wyaml.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "gorm":
if filePath == "" {
return fmt.Errorf("%s: file path is required for GORM format", label)
}
writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wgorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "bun":
if filePath == "" {
return fmt.Errorf("%s: file path is required for Bun format", label)
}
writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wbun.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "drizzle":
if filePath == "" {
return fmt.Errorf("%s: file path is required for Drizzle format", label)
}
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wdrizzle.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "prisma":
if filePath == "" {
return fmt.Errorf("%s: file path is required for Prisma format", label)
}
writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wprisma.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "typeorm":
if filePath == "" {
return fmt.Errorf("%s: file path is required for TypeORM format", label)
}
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wtypeorm.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "sqlite", "sqlite3":
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, ""))
writer = wsqlite.NewWriter(newWriterOptions(filePath, "", flattenSchema, "", false))
case "pgsql":
writerOpts := newWriterOptions(filePath, "", flattenSchema, "")
writerOpts := newWriterOptions(filePath, "", flattenSchema, "", false)
if connString != "" {
writerOpts.Metadata = map[string]interface{}{
"connection_string": connString,
+36
View File
@@ -3,6 +3,7 @@ package main
import (
"os"
"path/filepath"
"strings"
"testing"
)
@@ -160,3 +161,38 @@ func TestRunMerge_FromListMissingSourceType(t *testing.T) {
t.Error("expected error when neither --source-path nor --from-list is provided")
}
}
func TestRunMerge_PgsqlOutputRejectsColumnTypeConflict(t *testing.T) {
saved := saveMergeState()
defer restoreMergeState(saved)
dir := t.TempDir()
targetFile := filepath.Join(dir, "target.json")
sourceFile := filepath.Join(dir, "source.json")
writeTestJSONWithSingleColumnType(t, targetFile, "users", "integer")
writeTestJSONWithSingleColumnType(t, sourceFile, "users", "uuid")
mergeTargetType = "json"
mergeTargetPath = targetFile
mergeTargetConn = ""
mergeSourceType = "json"
mergeSourcePath = sourceFile
mergeSourceConn = ""
mergeFromList = nil
mergeOutputType = "pgsql"
mergeOutputPath = ""
mergeOutputConn = "postgres://relspec:secret@localhost/testdb"
mergeSkipTables = ""
mergeReportPath = ""
err := runMerge(nil, nil)
if err == nil {
t.Fatal("expected pgsql output merge to fail on column type conflict")
}
if !strings.Contains(err.Error(), "column type conflicts detected") {
t.Fatalf("expected conflict summary in error, got: %v", err)
}
if !strings.Contains(err.Error(), "public.users.id") {
t.Fatalf("expected conflicting column path in error, got: %v", err)
}
}
+2 -1
View File
@@ -13,12 +13,13 @@ func newReaderOptions(filePath, connString string) *readers.ReaderOptions {
}
}
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string) *writers.WriterOptions {
func newWriterOptions(outputPath, packageName string, flattenSchema bool, nullableTypes string, continueOnError bool) *writers.WriterOptions {
return &writers.WriterOptions{
OutputPath: outputPath,
PackageName: packageName,
FlattenSchema: flattenSchema,
NullableTypes: nullableTypes,
Prisma7: prisma7,
ContinueOnError: continueOnError,
}
}
+1
View File
@@ -188,6 +188,7 @@ func runSplit(cmd *cobra.Command, args []string) error {
"", // no schema filter for split
false, // no flatten-schema for split
splitNullableTypes,
false, // no continue-on-error for split
)
if err != nil {
return fmt.Errorf("failed to write output: %w", err)
+34
View File
@@ -71,6 +71,40 @@ func writeTestJSON(t *testing.T, path string, tableNames []string) {
}
}
func writeTestJSONWithSingleColumnType(t *testing.T, path, tableName, columnType string) {
t.Helper()
db := minimalDatabase{
Name: "test_db",
Schemas: []minimalSchema{{
Name: "public",
Tables: []minimalTable{{
Name: tableName,
Schema: "public",
Columns: map[string]minimalColumn{
"id": {
Name: "id",
Table: tableName,
Schema: "public",
Type: columnType,
NotNull: true,
IsPrimaryKey: true,
AutoIncrement: true,
},
},
}},
}},
}
data, err := json.Marshal(db)
if err != nil {
t.Fatalf("failed to marshal test JSON: %v", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatalf("failed to write test file %s: %v", path, err)
}
}
// convertState captures and restores all convert global vars.
type convertState struct {
sourceType string
+1 -1
View File
@@ -1,6 +1,6 @@
# Maintainer: Hein (Warky Devs) <hein@warky.dev>
pkgname=relspec
pkgver=1.0.50
pkgver=1.0.58
pkgrel=1
pkgdesc="RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs."
arch=('x86_64' 'aarch64')
+1 -1
View File
@@ -1,5 +1,5 @@
Name: relspec
Version: 1.0.50
Version: 1.0.58
Release: 1%{?dist}
Summary: RelSpec is a comprehensive database relations management tool that reads, transforms, and writes database table specifications across multiple formats and ORMs.
+156
View File
@@ -0,0 +1,156 @@
package mariadb
import "strings"
// MariaDBToCanonicalTypes maps MariaDB/MySQL type names to canonical types.
var MariaDBToCanonicalTypes = map[string]string{
// Integer types
"tinyint": "int8",
"smallint": "int16",
"mediumint": "int",
"int": "int",
"integer": "int",
"int2": "int16",
"int4": "int",
"int8": "int64",
"bigint": "int64",
// Boolean (TINYINT(1) alias)
"boolean": "bool",
"bool": "bool",
"bit": "bool",
// Float types
"float": "float32",
"double": "float64",
"real": "float64",
"double precision": "float64",
// Decimal types
"decimal": "decimal",
"numeric": "decimal",
"dec": "decimal",
"fixed": "decimal",
// String types
"char": "string",
"character": "string",
"varchar": "string",
"nchar": "string",
"nvarchar": "string",
"tinytext": "text",
"text": "text",
"mediumtext": "text",
"longtext": "text",
// Binary/blob types
"binary": "bytea",
"varbinary": "bytea",
"tinyblob": "bytea",
"blob": "bytea",
"mediumblob": "bytea",
"longblob": "bytea",
// Date/time types
"date": "date",
"time": "time",
"datetime": "timestamp",
"timestamp": "timestamp",
"year": "int",
// Other types
"json": "json",
"enum": "string",
"set": "string",
"uuid": "uuid",
}
// CanonicalToMariaDBTypes maps canonical types to MariaDB/MySQL types.
var CanonicalToMariaDBTypes = map[string]string{
"bool": "TINYINT(1)",
"int8": "TINYINT",
"int16": "SMALLINT",
"int": "INT",
"int32": "INT",
"int64": "BIGINT",
"uint": "INT UNSIGNED",
"uint8": "TINYINT UNSIGNED",
"uint16": "SMALLINT UNSIGNED",
"uint32": "INT UNSIGNED",
"uint64": "BIGINT UNSIGNED",
"float32": "FLOAT",
"float64": "DOUBLE",
"decimal": "DECIMAL",
"string": "VARCHAR(255)",
"text": "TEXT",
"date": "DATE",
"time": "TIME",
"timestamp": "DATETIME",
"timestamptz": "DATETIME",
"uuid": "CHAR(36)",
"json": "JSON",
"jsonb": "JSON",
"bytea": "BLOB",
}
// MariaDBTypeSynonyms maps MariaDB/MySQL type aliases to their canonical MariaDB name.
var MariaDBTypeSynonyms = map[string]string{
"integer": "int",
"int2": "smallint",
"int4": "int",
"int8": "bigint",
"double precision": "double",
"character": "char",
"dec": "decimal",
"fixed": "decimal",
"numeric": "decimal",
"boolean": "tinyint",
"bool": "tinyint",
}
// NormalizeMariaDBType maps a MariaDB/MySQL base type (no dimension parameters)
// to its canonical MariaDB form. Unknown types are returned as-is (lowercased).
func NormalizeMariaDBType(baseType string) string {
lower := strings.ToLower(strings.TrimSpace(baseType))
if canonical, ok := MariaDBTypeSynonyms[lower]; ok {
return canonical
}
return lower
}
// ConvertMariaDBToCanonical converts a MariaDB/MySQL type name to the canonical type.
// Strips dimension parameters and normalizes aliases. Defaults to "string".
func ConvertMariaDBToCanonical(mariadbType string) string {
base := strings.ToLower(strings.TrimSpace(mariadbType))
if idx := strings.Index(base, "("); idx >= 0 {
base = strings.TrimSpace(base[:idx])
}
if canonical, ok := MariaDBToCanonicalTypes[base]; ok {
return canonical
}
// Prefix match for composite types (e.g., "unsigned bigint")
for key, canonical := range MariaDBToCanonicalTypes {
if strings.HasPrefix(base, key) {
return canonical
}
}
return "string"
}
// ConvertCanonicalToMariaDB converts a canonical type to a MariaDB/MySQL type.
// Defaults to VARCHAR(255) for unrecognised types.
func ConvertCanonicalToMariaDB(canonicalType string) string {
lower := strings.ToLower(strings.TrimSpace(canonicalType))
if idx := strings.Index(lower, "("); idx >= 0 {
lower = strings.TrimSpace(lower[:idx])
}
if mariadbType, ok := CanonicalToMariaDBTypes[lower]; ok {
return mariadbType
}
// Prefix fallback
for canonical, mariadb := range CanonicalToMariaDBTypes {
if strings.HasPrefix(lower, canonical) {
return mariadb
}
}
return "VARCHAR(255)"
}
+126 -1
View File
@@ -5,9 +5,11 @@ package merge
import (
"fmt"
"strconv"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
)
// MergeResult represents the result of a merge operation
@@ -22,6 +24,16 @@ type MergeResult struct {
EnumsAdded int
ViewsAdded int
SequencesAdded int
TypeConflicts []ColumnTypeConflict
}
// ColumnTypeConflict describes a column that exists in both schemas but with incompatible types.
type ColumnTypeConflict struct {
Schema string
Table string
Column string
TargetType string
SourceType string
}
// MergeOptions contains options for merge operations
@@ -146,11 +158,19 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
// Merge columns
for colName, srcCol := range srcTable.Columns {
if _, exists := existingColumns[colName]; !exists {
if tgtCol, exists := existingColumns[colName]; !exists {
// Column doesn't exist, add it
newCol := cloneColumn(srcCol)
table.Columns[colName] = newCol
r.ColumnsAdded++
} else if columnTypeConflict(tgtCol, srcCol) {
r.TypeConflicts = append(r.TypeConflicts, ColumnTypeConflict{
Schema: firstNonEmpty(table.Schema, srcTable.Schema, srcCol.Schema),
Table: firstNonEmpty(table.Name, srcTable.Name, srcCol.Table),
Column: firstNonEmpty(tgtCol.Name, srcCol.Name, colName),
TargetType: describeColumnType(tgtCol),
SourceType: describeColumnType(srcCol),
})
}
}
}
@@ -426,6 +446,78 @@ func cloneColumn(col *models.Column) *models.Column {
return newCol
}
func columnTypeConflict(target, source *models.Column) bool {
if target == nil || source == nil {
return false
}
tType, tLen, tPrec, tScale := extractTypeParts(target)
sType, sLen, sPrec, sScale := extractTypeParts(source)
return tType != sType || tLen != sLen || tPrec != sPrec || tScale != sScale
}
// extractTypeParts returns the canonical base type and dimensions for a column,
// handling the case where dimensions are embedded in the type string (e.g. "char(2)")
// rather than stored in the separate Length/Precision/Scale fields.
func extractTypeParts(col *models.Column) (baseType string, length, precision, scale int) {
typeName := strings.ToLower(strings.TrimSpace(col.Type))
length, precision, scale = col.Length, col.Precision, col.Scale
if idx := strings.Index(typeName, "("); idx >= 0 {
inner := strings.TrimRight(strings.TrimSpace(typeName[idx+1:]), ")")
typeName = strings.TrimSpace(typeName[:idx])
parts := strings.Split(inner, ",")
if len(parts) == 2 {
if p, err := strconv.Atoi(strings.TrimSpace(parts[0])); err == nil && p > 0 && precision == 0 {
precision = p
}
if s, err := strconv.Atoi(strings.TrimSpace(parts[1])); err == nil && s > 0 && scale == 0 {
scale = s
}
} else if len(parts) == 1 {
if l, err := strconv.Atoi(strings.TrimSpace(parts[0])); err == nil && l > 0 && length == 0 && precision == 0 {
length = l
}
}
}
typeName = pgsql.NormalizePGType(typeName)
return typeName, length, precision, scale
}
func describeColumnType(col *models.Column) string {
if col == nil {
return ""
}
typeName := strings.TrimSpace(col.Type)
if typeName == "" {
return ""
}
switch {
case col.Precision > 0 && col.Scale > 0:
return fmt.Sprintf("%s(%d,%d)", typeName, col.Precision, col.Scale)
case col.Precision > 0:
return fmt.Sprintf("%s(%d)", typeName, col.Precision)
case col.Length > 0:
return fmt.Sprintf("%s(%d)", typeName, col.Length)
default:
return typeName
}
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return value
}
}
return ""
}
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
if constraint == nil {
return nil
@@ -609,6 +701,7 @@ func GetMergeSummary(result *MergeResult) string {
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
fmt.Sprintf("Type conflicts: %d", len(result.TypeConflicts)),
}
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
@@ -625,3 +718,35 @@ func GetMergeSummary(result *MergeResult) string {
return summary
}
// GetColumnTypeConflictSummary returns a short, human-readable conflict summary.
func GetColumnTypeConflictSummary(result *MergeResult, limit int) string {
if result == nil || len(result.TypeConflicts) == 0 {
return ""
}
if limit <= 0 {
limit = len(result.TypeConflicts)
}
lines := make([]string, 0, min(limit, len(result.TypeConflicts))+1)
lines = append(lines, "column type conflicts detected:")
for i, conflict := range result.TypeConflicts {
if i >= limit {
break
}
lines = append(lines, fmt.Sprintf(" - %s.%s.%s: target=%s source=%s",
conflict.Schema, conflict.Table, conflict.Column, conflict.TargetType, conflict.SourceType))
}
if len(result.TypeConflicts) > limit {
lines = append(lines, fmt.Sprintf(" ... and %d more", len(result.TypeConflicts)-limit))
}
return strings.Join(lines, "\n")
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
+62
View File
@@ -1,6 +1,7 @@
package merge
import (
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
@@ -140,6 +141,61 @@ func TestMergeColumns_NewColumn(t *testing.T) {
}
}
func TestMergeColumns_TypeConflictIsDetected(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar", Length: 255},
},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "text"},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if len(result.TypeConflicts) != 1 {
t.Fatalf("Expected 1 type conflict, got %d", len(result.TypeConflicts))
}
conflict := result.TypeConflicts[0]
if conflict.Schema != "public" || conflict.Table != "users" || conflict.Column != "email" {
t.Fatalf("Unexpected conflict location: %+v", conflict)
}
if conflict.TargetType != "varchar(255)" {
t.Fatalf("Expected target type varchar(255), got %q", conflict.TargetType)
}
if conflict.SourceType != "text" {
t.Fatalf("Expected source type text, got %q", conflict.SourceType)
}
if got := target.Schemas[0].Tables[0].Columns["email"].Type; got != "varchar" {
t.Fatalf("Expected target column type to remain unchanged, got %q", got)
}
}
func TestMergeConstraints_NewConstraint(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
@@ -509,6 +565,9 @@ func TestGetMergeSummary(t *testing.T) {
ConstraintsAdded: 3,
IndexesAdded: 2,
ViewsAdded: 1,
TypeConflicts: []ColumnTypeConflict{
{Schema: "public", Table: "users", Column: "email", TargetType: "varchar(255)", SourceType: "text"},
},
}
summary := GetMergeSummary(result)
@@ -518,6 +577,9 @@ func TestGetMergeSummary(t *testing.T) {
if len(summary) < 50 {
t.Errorf("Summary seems too short: %s", summary)
}
if !strings.Contains(summary, "Type conflicts: 1") {
t.Errorf("Expected type conflict count in summary, got: %s", summary)
}
}
func TestGetMergeSummary_Nil(t *testing.T) {
+76 -32
View File
@@ -2,32 +2,73 @@ package mssql
import "strings"
// CanonicalToMSSQLTypes maps canonical types to MSSQL types
// CanonicalToMSSQLTypes maps canonical types to MSSQL types.
// Accepts both Go canonical names ("int", "string") and SQL canonical names
// ("integer", "varchar") so the writer handles input from any reader.
var CanonicalToMSSQLTypes = map[string]string{
// Boolean — Go and SQL canonical
"bool": "BIT",
"boolean": "BIT",
// Integer — Go canonical
"int8": "TINYINT",
"int16": "SMALLINT",
"int": "INT",
"int32": "INT",
"int64": "BIGINT",
"uint": "BIGINT",
"uint8": "SMALLINT",
"uint16": "INT",
"uint8": "TINYINT",
"uint16": "SMALLINT",
"uint32": "BIGINT",
"uint64": "BIGINT",
// Integer — SQL canonical (serial types map to base integer; IDENTITY is set via AutoIncrement)
"integer": "INT",
"smallint": "SMALLINT",
"bigint": "BIGINT",
"tinyint": "TINYINT",
"serial": "INT",
"smallserial": "SMALLINT",
"bigserial": "BIGINT",
// Float — Go canonical
"float32": "REAL",
"float64": "FLOAT",
// Float — SQL canonical
"real": "REAL",
"double precision": "FLOAT",
"double": "FLOAT",
// Decimal/numeric
"decimal": "NUMERIC",
"numeric": "NUMERIC",
"money": "MONEY",
// String — Go canonical
"string": "NVARCHAR(255)",
"text": "NVARCHAR(MAX)",
// String — SQL canonical
"varchar": "NVARCHAR(255)",
"char": "NCHAR",
"nvarchar": "NVARCHAR(255)",
"nchar": "NCHAR",
"citext": "NVARCHAR(MAX)",
// Date/time
"date": "DATE",
"time": "TIME",
"timetz": "DATETIMEOFFSET",
"timestamp": "DATETIME2",
"timestamptz": "DATETIMEOFFSET",
"datetime": "DATETIME2",
"interval": "NVARCHAR(50)",
// UUID
"uuid": "UNIQUEIDENTIFIER",
// JSON — MSSQL has no native JSON type; stored as NVARCHAR(MAX)
"json": "NVARCHAR(MAX)",
"jsonb": "NVARCHAR(MAX)",
// Binary
"bytea": "VARBINARY(MAX)",
"blob": "VARBINARY(MAX)",
// Network/geo types — no MSSQL native equivalent
"xml": "XML",
"inet": "NVARCHAR(45)",
"cidr": "NVARCHAR(43)",
"macaddr": "NVARCHAR(17)",
}
// MSSQLToCanonicalTypes maps MSSQL types to canonical types
@@ -68,47 +109,50 @@ var MSSQLToCanonicalTypes = map[string]string{
"geometry": "string",
}
// ConvertCanonicalToMSSQL converts a canonical type to MSSQL type
// MSSQLTypeSynonyms maps MSSQL type aliases to their canonical MSSQL name.
var MSSQLTypeSynonyms = map[string]string{
"integer": "int",
"dec": "decimal",
"float(n)": "float",
}
// NormalizeMSSQLType maps an MSSQL base type (no dimension parameters) to its
// canonical MSSQL form. Unknown types are returned as-is (lowercased).
func NormalizeMSSQLType(baseType string) string {
lower := strings.ToLower(strings.TrimSpace(baseType))
if canonical, ok := MSSQLTypeSynonyms[lower]; ok {
return canonical
}
return lower
}
// ConvertCanonicalToMSSQL converts a canonical type (Go or SQL) to an MSSQL type.
// Strips dimension parameters before lookup. Defaults to NVARCHAR(255) for unknown types.
func ConvertCanonicalToMSSQL(canonicalType string) string {
// Check direct mapping
if mssqlType, exists := CanonicalToMSSQLTypes[strings.ToLower(canonicalType)]; exists {
base := strings.ToLower(strings.TrimSpace(canonicalType))
if idx := strings.Index(base, "("); idx >= 0 {
base = strings.TrimSpace(base[:idx])
}
base = strings.TrimSuffix(base, "[]")
if mssqlType, exists := CanonicalToMSSQLTypes[base]; exists {
return mssqlType
}
// Try to find by prefix
lowerType := strings.ToLower(canonicalType)
for canonical, mssql := range CanonicalToMSSQLTypes {
if strings.HasPrefix(lowerType, canonical) {
return mssql
}
}
// Default to NVARCHAR
return "NVARCHAR(255)"
}
// ConvertMSSQLToCanonical converts an MSSQL type to canonical type
// ConvertMSSQLToCanonical converts an MSSQL type to the canonical type.
// Strips dimension parameters before lookup. Defaults to "string" for unknown types.
func ConvertMSSQLToCanonical(mssqlType string) string {
// Extract base type (remove parentheses and parameters)
baseType := mssqlType
if idx := strings.Index(baseType, "("); idx != -1 {
baseType = baseType[:idx]
base := strings.ToLower(strings.TrimSpace(mssqlType))
if idx := strings.Index(base, "("); idx >= 0 {
base = strings.TrimSpace(base[:idx])
}
baseType = strings.TrimSpace(baseType)
// Check direct mapping
if canonicalType, exists := MSSQLToCanonicalTypes[strings.ToLower(baseType)]; exists {
if canonicalType, exists := MSSQLToCanonicalTypes[base]; exists {
return canonicalType
}
// Try to find by prefix
lowerType := strings.ToLower(baseType)
for mssql, canonical := range MSSQLToCanonicalTypes {
if strings.HasPrefix(lowerType, mssql) {
return canonical
}
}
// Default to string
return "string"
}
+58
View File
@@ -45,6 +45,7 @@ var GoToStdTypes = map[string]string{
"sqldate": "date",
"sqltime": "time",
"sqltimestamp": "timestamp",
"time.Time": "timestamp",
}
var GoToPGSQLTypes = map[string]string{
@@ -90,6 +91,7 @@ var GoToPGSQLTypes = map[string]string{
"sqldate": "date",
"sqltime": "time",
"sqltimestamp": "timestamp",
"time.Time": "timestamp",
"citext": "citext",
}
@@ -135,6 +137,62 @@ func ConvertSQLType(anytype string) string {
return anytype
}
// PGTypeCanonical maps PostgreSQL type aliases and synonyms to their canonical base name.
// Input should be a base type (no dimension parameters, lowercase).
var PGTypeCanonical = map[string]string{
// integer aliases
"int": "integer",
"int4": "integer",
"int2": "smallint",
"int8": "bigint",
// float aliases
"float4": "real",
"float8": "double precision",
// bool alias
"bool": "boolean",
// char aliases
"character": "char",
"character varying": "varchar",
"bpchar": "char",
// timestamp aliases
"timestamp without time zone": "timestamp",
"timestamp with time zone": "timestamptz",
// time aliases
"time without time zone": "time",
"time with time zone": "timetz",
// decimal alias
"decimal": "numeric",
}
// knownPGBaseTypes is the set of canonical PostgreSQL base types (no aliases).
var knownPGBaseTypes = map[string]struct{}{
"integer": {}, "bigint": {}, "smallint": {},
"serial": {}, "bigserial": {}, "smallserial": {},
"numeric": {}, "real": {}, "double precision": {}, "money": {},
"varchar": {}, "char": {}, "text": {}, "citext": {},
"boolean": {},
"date": {}, "time": {}, "timetz": {}, "timestamp": {}, "timestamptz": {}, "interval": {},
"uuid": {}, "json": {}, "jsonb": {}, "bytea": {},
"inet": {}, "cidr": {}, "macaddr": {}, "xml": {},
}
// NormalizePGType maps a PostgreSQL base type (no dimension parameters) to its
// canonical form. Unknown types are returned as-is (lowercased).
func NormalizePGType(baseType string) string {
lower := strings.ToLower(strings.TrimSpace(baseType))
if canonical, ok := PGTypeCanonical[lower]; ok {
return canonical
}
return lower
}
// IsKnownPGBaseType reports whether the given name (after NormalizePGType) is a
// recognized built-in PostgreSQL type. Custom types (e.g. vector, postgis) return false.
func IsKnownPGBaseType(baseType string) bool {
_, ok := knownPGBaseTypes[strings.ToLower(strings.TrimSpace(baseType))]
return ok
}
func IsGoType(pTypeName string) bool {
for k := range GoToStdTypes {
if strings.EqualFold(pTypeName, k) {
+98
View File
@@ -145,6 +145,24 @@ var postgresTypeAliases = map[string]string{
"bool": "boolean",
}
var postgresEquivalentBaseTypes = map[string]string{
"character varying": "varchar",
"character": "char",
"timestamp without time zone": "timestamp",
"timestamp with time zone": "timestamptz",
"time without time zone": "time",
"time with time zone": "timetz",
}
var postgresEquivalentBaseTypeVariants = map[string][]string{
"varchar": {"varchar", "character varying"},
"char": {"char", "character"},
"timestamp": {"timestamp", "timestamp without time zone"},
"timestamptz": {"timestamptz", "timestamp with time zone"},
"time": {"time", "time without time zone"},
"timetz": {"timetz", "time with time zone"},
}
// GetPostgresBaseTypes returns a sorted-ish stable list of registered base type names.
func GetPostgresBaseTypes() []string {
result := make([]string, 0, len(postgresBaseTypes))
@@ -212,6 +230,86 @@ func CanonicalizeBaseType(baseType string) string {
return base
}
// EquivalentBaseType resolves broader SQL-equivalent spellings to a common comparable form.
func EquivalentBaseType(baseType string) string {
base := CanonicalizeBaseType(baseType)
if equivalent, ok := postgresEquivalentBaseTypes[base]; ok {
return equivalent
}
return base
}
// NormalizeEquivalentSQLType returns a normalized SQL type string suitable for equality checks.
// Equivalent spellings such as "character varying(255)" and "varchar(255)" normalize identically.
func NormalizeEquivalentSQLType(sqlType string) string {
t := normalizeTypeToken(sqlType)
if t == "" {
return ""
}
arrayDepth := 0
for strings.HasSuffix(t, "[]") {
arrayDepth++
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
}
modifier := ""
if idx := strings.Index(t, "("); idx >= 0 {
modifier = strings.TrimSpace(t[idx:])
t = strings.TrimSpace(t[:idx])
}
base := EquivalentBaseType(t)
normalized := base + modifier
for i := 0; i < arrayDepth; i++ {
normalized += "[]"
}
return normalized
}
// EquivalentSQLTypeVariants returns equivalent PostgreSQL spellings for a SQL type.
// Examples:
// - varchar(255) -> ["varchar(255)", "character varying(255)"]
// - timestamptz -> ["timestamptz", "timestamp with time zone"]
func EquivalentSQLTypeVariants(sqlType string) []string {
t := normalizeTypeToken(sqlType)
if t == "" {
return nil
}
arrayDepth := 0
for strings.HasSuffix(t, "[]") {
arrayDepth++
t = strings.TrimSpace(strings.TrimSuffix(t, "[]"))
}
modifier := ""
if idx := strings.Index(t, "("); idx >= 0 {
modifier = strings.TrimSpace(t[idx:])
t = strings.TrimSpace(t[:idx])
}
base := EquivalentBaseType(t)
bases := postgresEquivalentBaseTypeVariants[base]
if len(bases) == 0 {
bases = []string{base}
}
seen := make(map[string]bool, len(bases))
result := make([]string, 0, len(bases))
for _, variantBase := range bases {
variant := variantBase + modifier
for i := 0; i < arrayDepth; i++ {
variant += "[]"
}
if !seen[variant] {
seen[variant] = true
result = append(result, variant)
}
}
return result
}
// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
func IsKnownPostgresType(sqlType string) bool {
base := CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))
+48
View File
@@ -97,3 +97,51 @@ func TestPostgresTypeRegistry_TypeParsingAndCapabilities(t *testing.T) {
})
}
}
func TestNormalizeEquivalentSQLType(t *testing.T) {
tests := []struct {
input string
want string
}{
{input: "character varying(255)", want: "varchar(255)"},
{input: "varchar(255)", want: "varchar(255)"},
{input: "timestamp with time zone", want: "timestamptz"},
{input: "timestamptz", want: "timestamptz"},
{input: "time without time zone", want: "time"},
{input: "character varying(255)[]", want: "varchar(255)[]"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := NormalizeEquivalentSQLType(tt.input)
if got != tt.want {
t.Fatalf("NormalizeEquivalentSQLType(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestEquivalentSQLTypeVariants(t *testing.T) {
tests := []struct {
input string
want []string
}{
{input: "character varying(255)", want: []string{"varchar(255)", "character varying(255)"}},
{input: "timestamptz", want: []string{"timestamptz", "timestamp with time zone"}},
{input: "text[]", want: []string{"text[]"}},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := EquivalentSQLTypeVariants(tt.input)
if len(got) != len(tt.want) {
t.Fatalf("EquivalentSQLTypeVariants(%q) len = %d, want %d (%v)", tt.input, len(got), len(tt.want), got)
}
for i := range tt.want {
if got[i] != tt.want[i] {
t.Fatalf("EquivalentSQLTypeVariants(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i])
}
}
})
}
}
+8
View File
@@ -270,8 +270,16 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
}
if numPrecision != nil {
// For integer and serial types, numeric_precision is a bit-width (32, 64, 16)
// not a user-visible column parameter. Only store precision for types where
// it represents actual decimal/scale precision (numeric, decimal, float).
switch column.Type {
case "integer", "bigint", "smallint", "serial", "bigserial", "smallserial":
// skip — bit-width, not a column parameter
default:
column.Precision = *numPrecision
}
}
if numScale != nil {
column.Scale = *numScale
+28 -62
View File
@@ -259,12 +259,14 @@ func (r *Reader) close() {
}
}
// mapDataType maps PostgreSQL data types while preserving exact type text when available.
// mapDataType maps a PostgreSQL data type to its canonical RelSpec name.
// For known built-in types, dimensions are stripped from the type string (they are
// stored separately in column.Length/Precision/Scale). For custom types (e.g.
// vector(1536), postgis geometries), the full formatted type is preserved.
func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
// If the column has a nextval default, it's likely a serial type
// Map to the appropriate serial type instead of the base integer type
// Detect serial types from nextval defaults before anything else.
if hasNextval {
switch normalizedPGType {
case "integer", "int", "int4":
@@ -276,73 +278,38 @@ func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval b
}
}
// Prefer the database-provided formatted type; this preserves arrays/custom
// types/modifiers like text[], vector(1536), numeric(10,2), etc.
if strings.TrimSpace(formattedType) != "" {
return formattedType
}
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
return udtName[1:] + "[]"
}
// Map common PostgreSQL types
typeMap := map[string]string{
"integer": "integer",
"bigint": "bigint",
"smallint": "smallint",
"int": "integer",
"int2": "smallint",
"int4": "integer",
"int8": "bigint",
"serial": "serial",
"bigserial": "bigserial",
"smallserial": "smallserial",
"numeric": "numeric",
"decimal": "decimal",
"real": "real",
"double precision": "double precision",
"float4": "real",
"float8": "double precision",
"money": "money",
"character varying": "varchar",
"varchar": "varchar",
"character": "char",
"char": "char",
"text": "text",
"boolean": "boolean",
"bool": "boolean",
"date": "date",
"time": "time",
"time without time zone": "time",
"time with time zone": "timetz",
"timestamp": "timestamp",
"timestamp without time zone": "timestamp",
"timestamp with time zone": "timestamptz",
"timestamptz": "timestamptz",
"interval": "interval",
"uuid": "uuid",
"json": "json",
"jsonb": "jsonb",
"bytea": "bytea",
"inet": "inet",
"cidr": "cidr",
"macaddr": "macaddr",
"xml": "xml",
// Use the database-formatted type when available. For known built-in types, strip
// embedded dimensions (they are stored in column.Length/Precision/Scale separately).
// For unknown/custom types, keep the full formatted string (e.g. vector(1536)).
if strings.TrimSpace(formattedType) != "" {
lower := strings.ToLower(strings.TrimSpace(formattedType))
isArray := strings.HasSuffix(lower, "[]")
base := strings.TrimSuffix(lower, "[]")
if idx := strings.Index(base, "("); idx >= 0 {
base = strings.TrimSpace(base[:idx])
}
canonical := pgsql.NormalizePGType(base)
if pgsql.IsKnownPGBaseType(canonical) {
if isArray {
return canonical + "[]"
}
return canonical
}
return formattedType
}
// Try mapped type first
if mapped, exists := typeMap[normalizedPGType]; exists {
return mapped
// Fall back to normalizing the information_schema type name directly.
canonical := pgsql.NormalizePGType(normalizedPGType)
if pgsql.IsKnownPGBaseType(canonical) {
return canonical
}
// Use pgsql utilities if available
if pgsql.ValidSQLType(pgType) {
return pgsql.GetSQLType(pgType)
}
// Return UDT name for custom types (including array fallback when needed)
// Return UDT name for custom types.
if udtName != "" {
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
return udtName[1:] + "[]"
@@ -350,7 +317,6 @@ func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval b
return udtName
}
// Default to the original type
return pgType
}
+1 -1
View File
@@ -198,7 +198,7 @@ func TestMapDataType(t *testing.T) {
{"unknown_type", "custom", "", "custom"}, // Should return UDT name
{"ARRAY", "_text", "", "text[]"},
{"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"},
{"character varying", "varchar", "character varying(255)", "character varying(255)"},
{"character varying", "varchar", "character varying(255)", "varchar"},
}
for _, tt := range tests {
+3 -52
View File
@@ -10,6 +10,7 @@ import (
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
sqlitepkg "git.warky.dev/wdevs/relspecgo/pkg/sqlite"
)
// Reader implements the readers.Reader interface for SQLite databases
@@ -183,59 +184,9 @@ func (r *Reader) close() {
}
}
// mapDataType maps SQLite data types to canonical types
// mapDataType maps SQLite data types to canonical types.
func (r *Reader) mapDataType(sqliteType string) string {
// SQLite has a flexible type system, but we map common types
typeMap := map[string]string{
"INTEGER": "int",
"INT": "int",
"TINYINT": "int8",
"SMALLINT": "int16",
"MEDIUMINT": "int",
"BIGINT": "int64",
"UNSIGNED BIG INT": "uint64",
"INT2": "int16",
"INT8": "int64",
"REAL": "float64",
"DOUBLE": "float64",
"DOUBLE PRECISION": "float64",
"FLOAT": "float32",
"NUMERIC": "decimal",
"DECIMAL": "decimal",
"BOOLEAN": "bool",
"BOOL": "bool",
"DATE": "date",
"DATETIME": "timestamp",
"TIMESTAMP": "timestamp",
"TEXT": "string",
"VARCHAR": "string",
"CHAR": "string",
"CHARACTER": "string",
"VARYING CHARACTER": "string",
"NCHAR": "string",
"NVARCHAR": "string",
"CLOB": "text",
"BLOB": "bytea",
}
// Try exact match first
if mapped, exists := typeMap[sqliteType]; exists {
return mapped
}
// Try case-insensitive match for common types
sqliteTypeUpper := sqliteType
if len(sqliteType) > 0 {
// Extract base type (e.g., "VARCHAR(255)" -> "VARCHAR")
for baseType := range typeMap {
if len(sqliteTypeUpper) >= len(baseType) && sqliteTypeUpper[:len(baseType)] == baseType {
return typeMap[baseType]
}
}
}
// Default to string for unknown types
return "string"
return sqlitepkg.ConvertSQLiteToCanonical(sqliteType)
}
// deriveRelationship creates a relationship from a foreign key constraint
+152
View File
@@ -0,0 +1,152 @@
package sqlite
import "strings"
// SQLiteToCanonicalTypes maps SQLite type names to canonical types.
// SQLite has type affinity rules; this maps common type names including
// MySQL/PostgreSQL types that users write in SQLite schemas.
var SQLiteToCanonicalTypes = map[string]string{
// Integer affinity
"integer": "int",
"int": "int",
"tinyint": "int8",
"smallint": "int16",
"mediumint": "int",
"bigint": "int64",
"unsigned big int": "uint64",
"int2": "int16",
"int8": "int64",
// Real affinity
"real": "float64",
"double": "float64",
"double precision": "float64",
"float": "float32",
// Numeric affinity
"numeric": "decimal",
"decimal": "decimal",
// Boolean (stored as integer in SQLite)
"boolean": "bool",
"bool": "bool",
// Date/time (stored as text in SQLite)
"date": "date",
"datetime": "timestamp",
"timestamp": "timestamp",
// Text affinity
"text": "string",
"varchar": "string",
"char": "string",
"character": "string",
"varying character": "string",
"nchar": "string",
"nvarchar": "string",
"clob": "text",
// Blob affinity
"blob": "bytea",
}
// CanonicalToSQLiteAffinity maps type names to SQLite type affinity names.
// Accepts both Go canonical names ("int", "string") and SQL canonical names
// ("integer", "varchar") so the writer handles input from any reader.
// The five SQLite type affinities are TEXT, INTEGER, REAL, NUMERIC, BLOB.
var CanonicalToSQLiteAffinity = map[string]string{
// INTEGER affinity — Go canonical
"int": "INTEGER",
"int8": "INTEGER",
"int16": "INTEGER",
"int32": "INTEGER",
"int64": "INTEGER",
"uint": "INTEGER",
"uint8": "INTEGER",
"uint16": "INTEGER",
"uint32": "INTEGER",
"uint64": "INTEGER",
"bool": "INTEGER",
// INTEGER affinity — SQL canonical
"integer": "INTEGER",
"smallint": "INTEGER",
"bigint": "INTEGER",
"serial": "INTEGER",
"smallserial": "INTEGER",
"bigserial": "INTEGER",
"boolean": "INTEGER",
"tinyint": "INTEGER",
"mediumint": "INTEGER",
// REAL affinity — Go canonical
"float32": "REAL",
"float64": "REAL",
// REAL affinity — SQL canonical
"real": "REAL",
"float": "REAL",
"double": "REAL",
"double precision": "REAL",
// NUMERIC affinity
"decimal": "NUMERIC",
"numeric": "NUMERIC",
"money": "NUMERIC",
"smallmoney": "NUMERIC",
// BLOB affinity
"bytea": "BLOB",
"blob": "BLOB",
// TEXT affinity — Go canonical
"string": "TEXT",
"text": "TEXT",
// TEXT affinity — SQL canonical
"varchar": "TEXT",
"char": "TEXT",
"nvarchar": "TEXT",
"nchar": "TEXT",
"citext": "TEXT",
"date": "TEXT",
"time": "TEXT",
"timetz": "TEXT",
"timestamp": "TEXT",
"timestamptz": "TEXT",
"datetime": "TEXT",
"uuid": "TEXT",
"json": "TEXT",
"jsonb": "TEXT",
"xml": "TEXT",
"inet": "TEXT",
"cidr": "TEXT",
"macaddr": "TEXT",
}
// ConvertSQLiteToCanonical converts a SQLite type name to the canonical type.
// Strips dimension parameters (e.g. VARCHAR(255) → string) and handles
// SQLite's flexible affinity rules. Defaults to "string" for unknown types.
func ConvertSQLiteToCanonical(sqliteType string) string {
base := strings.ToUpper(strings.TrimSpace(sqliteType))
if idx := strings.Index(base, "("); idx >= 0 {
base = strings.TrimSpace(base[:idx])
}
lower := strings.ToLower(base)
if canonical, ok := SQLiteToCanonicalTypes[lower]; ok {
return canonical
}
// Prefix match for types like "VARYING CHARACTER(255)"
for key, canonical := range SQLiteToCanonicalTypes {
if strings.HasPrefix(lower, key) {
return canonical
}
}
return "string"
}
// ConvertCanonicalToSQLite converts a canonical type (or any SQL type) to its
// SQLite type affinity. Defaults to TEXT for unrecognised types.
func ConvertCanonicalToSQLite(canonicalType string) string {
normalized := strings.ToLower(strings.TrimSpace(canonicalType))
if idx := strings.Index(normalized, "("); idx >= 0 {
normalized = strings.TrimSpace(normalized[:idx])
}
normalized = strings.TrimSuffix(normalized, "[]")
if affinity, ok := CanonicalToSQLiteAffinity[normalized]; ok {
return affinity
}
return "TEXT"
}
+18
View File
@@ -26,7 +26,10 @@ type ModelData struct {
Fields []*FieldData
Config *MethodConfig
PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses SQL type (needs .Int64() call)
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
}
@@ -140,7 +143,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
model.IDColumnName = safeName
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
break
}
}
@@ -192,6 +201,15 @@ func formatComment(description, comment string) string {
return comment
}
func isStringLikePrimaryKeyType(goType string) bool {
switch goType {
case "string", "sql.NullString", "resolvespec_common.SqlString", "resolvespec_common.SqlUUID":
return true
default:
return false
}
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string {
+27 -5
View File
@@ -44,33 +44,55 @@ func (m {{.Name}}) SchemaName() string {
{{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value
func (m {{.Name}}) GetID() int64 {
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return int64(m.{{.PrimaryKeyField}})
{{- end}}
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
}
{{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) {
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid)
}
{{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) {
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}}
m.{{.PrimaryKeyField}} = int32(newid)
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}}
}
{{end}}
+6 -2
View File
@@ -311,10 +311,11 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
if column.Type != "" {
// Sanitize type to remove backticks
typeStr := writers.SanitizeStructTagValue(column.Type)
isArray := pgsql.IsArrayType(typeStr)
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr)
if !hasExplicitTypeModifier && column.Length > 0 {
if !hasExplicitTypeModifier && !isArray && column.Length > 0 {
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
} else if !hasExplicitTypeModifier && column.Precision > 0 {
} else if !hasExplicitTypeModifier && !isArray && column.Precision > 0 {
if column.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
} else {
@@ -322,6 +323,9 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
}
}
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
if isArray && tm.typeStyle == writers.NullableTypeStdlib {
parts = append(parts, "array")
}
}
// Primary key
+24 -4
View File
@@ -102,8 +102,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
}
}
// Add fmt import if GetIDStr is enabled
if w.config.GenerateGetIDStr {
// Add fmt import when generated helper methods need string formatting.
if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"")
}
@@ -195,8 +195,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
}
}
// Add fmt import if GetIDStr is enabled
if w.config.GenerateGetIDStr {
// Add fmt import when generated helper methods need string formatting.
if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"")
}
@@ -301,6 +301,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
}
}
func (w *Writer) needsFmtImport(models []*ModelData) bool {
if w.config.GenerateGetIDStr {
for _, model := range models {
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
if w.config.GenerateUpdateID {
for _, model := range models {
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
return false
}
// findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas {
+156
View File
@@ -574,6 +574,10 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
{"boolean", false, "resolvespec_common.SqlBool"},
{"uuid", false, "resolvespec_common.SqlUUID"},
{"jsonb", false, "resolvespec_common.SqlJSONB"},
{"text[]", true, "resolvespec_common.SqlStringArray"},
{"text[]", false, "resolvespec_common.SqlStringArray"},
{"integer[]", true, "resolvespec_common.SqlInt32Array"},
{"bigint[]", false, "resolvespec_common.SqlInt64Array"},
}
for _, tt := range tests {
@@ -586,6 +590,116 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
}
}
func TestWriter_UpdateIDTypeSafety_Bun(t *testing.T) {
tests := []struct {
name string
pkType string
expectedPK string
expectedLine string
forbidInt32 bool
}{
{"int32_pk", "int", "int32", "m.ID = int32(newid)", false},
{"sql_int16_pk", "smallint", "resolvespec_common.SqlInt16", "m.ID.FromString(fmt.Sprintf(\"%d\", newid))", true},
{"int64_pk", "bigint", "int64", "m.ID = int64(newid)", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table := models.InitTable("test_table", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: tt.pkType,
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
if !strings.Contains(generated, tt.expectedLine) {
t.Errorf("Expected UpdateID to include %s\nGenerated:\n%s", tt.expectedLine, generated)
}
if !strings.Contains(generated, "ID "+tt.expectedPK) {
t.Errorf("Expected generated primary key field type %s\nGenerated:\n%s", tt.expectedPK, generated)
}
if tt.forbidInt32 && strings.Contains(generated, "int32(newid)") {
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
}
if !strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
}
})
}
}
func TestWriter_StringPrimaryKeyHelpers_Bun(t *testing.T) {
table := models.InitTable("accounts", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: "uuid",
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
expectations := []string{
"resolvespec_common.SqlUUID",
"func (m ModelPublicAccounts) GetID() string",
"return m.ID.String()",
"func (m ModelPublicAccounts) GetIDStr() string",
"func (m ModelPublicAccounts) SetID(newid string)",
"func (m *ModelPublicAccounts) UpdateID(newid string)",
"m.ID.FromString(newid)",
}
for _, expected := range expectations {
if !strings.Contains(generated, expected) {
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
}
}
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
}
}
func TestTypeMapper_BuildBunTag(t *testing.T) {
mapper := NewTypeMapper("")
@@ -685,6 +799,24 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
},
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
},
{
name: "text array type",
column: &models.Column{
Name: "tags",
Type: "text[]",
NotNull: false,
},
want: []string{"tags,", "type:text[],"},
},
{
name: "integer array type",
column: &models.Column{
Name: "scores",
Type: "integer[]",
NotNull: true,
},
want: []string{"scores,", "type:integer[],"},
},
}
for _, tt := range tests {
@@ -695,6 +827,30 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
t.Errorf("BuildBunTag() = %q, missing %q", result, part)
}
}
// resolvespec mode must NOT add "array" — SqlXxxArray uses sql.Scanner
if strings.Contains(result, ",array,") || strings.HasSuffix(result, ",array,") {
t.Errorf("BuildBunTag() = %q, must not contain 'array' in resolvespec mode", result)
}
})
}
}
func TestTypeMapper_BuildBunTag_StdlibArrayHasArrayTag(t *testing.T) {
mapper := NewTypeMapper(writers.NullableTypeStdlib)
cases := []struct {
name string
column *models.Column
}{
{name: "text array", column: &models.Column{Name: "tags", Type: "text[]"}},
{name: "integer array", column: &models.Column{Name: "scores", Type: "integer[]", NotNull: true}},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
result := mapper.BuildBunTag(tt.column, nil)
if !strings.Contains(result, "array") {
t.Errorf("BuildBunTag() = %q, expected 'array' in stdlib mode", result)
}
})
}
}
+21 -1
View File
@@ -2,6 +2,7 @@ package gorm
import (
"sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
@@ -26,6 +27,9 @@ type ModelData struct {
Config *MethodConfig
PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
PrimaryKeyIsSQL bool // Whether PK uses a SQL wrapper type
PrimaryKeyIsStr bool // Whether helper methods should use string IDs
PrimaryKeyIDType string // Helper method GetID/SetID/UpdateID type
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
}
@@ -136,7 +140,14 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, fl
// Sanitize column name to remove backticks
safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyType = goType
model.PrimaryKeyIsSQL = strings.Contains(goType, "sql_types.") || strings.Contains(goType, "sql.")
model.PrimaryKeyIsStr = isStringLikePrimaryKeyType(goType)
model.PrimaryKeyIDType = "int64"
if model.PrimaryKeyIsStr {
model.PrimaryKeyIDType = "string"
}
model.IDColumnName = safeName
break
}
@@ -189,6 +200,15 @@ func formatComment(description, comment string) string {
return comment
}
func isStringLikePrimaryKeyType(goType string) bool {
switch goType {
case "string", "sql.NullString", "sql_types.SqlString", "sql_types.SqlUUID":
return true
default:
return false
}
}
// resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string {
+33 -3
View File
@@ -43,26 +43,56 @@ func (m {{.Name}}) SchemaName() string {
{{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value
func (m {{.Name}}) GetID() int64 {
func (m {{.Name}}) GetID() {{.PrimaryKeyIDType}} {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}.String()
{{- else -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return int64(m.{{.PrimaryKeyField}})
{{- end}}
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.String()
{{- else if .PrimaryKeyIsStr -}}
return m.{{.PrimaryKeyField}}
{{- else -}}
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
{{- end}}
}
{{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) {
func (m {{.Name}}) SetID(newid {{.PrimaryKeyIDType}}) {
m.UpdateID(newid)
}
{{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) {
func (m *{{.Name}}) UpdateID(newid {{.PrimaryKeyIDType}}) {
{{if .PrimaryKeyIsSQL -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}}.FromString(newid)
{{- else -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- end}}
{{- else -}}
{{if .PrimaryKeyIsStr -}}
m.{{.PrimaryKeyField}} = newid
{{- else -}}
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
{{- end}}
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}}
+24 -4
View File
@@ -99,8 +99,8 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
}
}
// Add fmt import if GetIDStr is enabled
if w.config.GenerateGetIDStr {
// Add fmt import when generated helper methods need string formatting.
if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"")
}
@@ -189,8 +189,8 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
}
}
// Add fmt import if GetIDStr is enabled
if w.config.GenerateGetIDStr {
// Add fmt import when generated helper methods need string formatting.
if w.needsFmtImport(templateData.Models) {
templateData.AddImport("\"fmt\"")
}
@@ -295,6 +295,26 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
}
}
func (w *Writer) needsFmtImport(models []*ModelData) bool {
if w.config.GenerateGetIDStr {
for _, model := range models {
if model.PrimaryKeyField != "" && !model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
if w.config.GenerateUpdateID {
for _, model := range models {
if model.PrimaryKeyField != "" && model.PrimaryKeyIsSQL && !model.PrimaryKeyIsStr {
return true
}
}
}
return false
}
// findTable finds a table by schema and name in the database
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
for _, schema := range db.Schemas {
+68
View File
@@ -598,6 +598,55 @@ func TestWriter_UpdateIDTypeSafety(t *testing.T) {
}
}
func TestWriter_StringPrimaryKeyHelpers_Gorm(t *testing.T) {
table := models.InitTable("accounts", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: "uuid",
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
expectations := []string{
"ID string",
"func (m ModelPublicAccounts) GetID() string",
"return m.ID",
"func (m ModelPublicAccounts) GetIDStr() string",
"func (m ModelPublicAccounts) SetID(newid string)",
"func (m *ModelPublicAccounts) UpdateID(newid string)",
"m.ID = newid",
}
for _, expected := range expectations {
if !strings.Contains(generated, expected) {
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
}
}
if strings.Contains(generated, "GetID() int64") || strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("String primary keys should not use int64 helper signatures\nGenerated:\n%s", generated)
}
}
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
tests := []struct {
input string
@@ -658,6 +707,10 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
{"timestamp", false, "sql_types.SqlTimeStamp"},
{"boolean", true, "bool"},
{"boolean", false, "sql_types.SqlBool"},
{"text[]", true, "sql_types.SqlStringArray"},
{"text[]", false, "sql_types.SqlStringArray"},
{"integer[]", true, "sql_types.SqlInt32Array"},
{"bigint[]", false, "sql_types.SqlInt64Array"},
}
for _, tt := range tests {
@@ -670,6 +723,21 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
}
}
func TestTypeMapper_BuildGormTag_ArrayType(t *testing.T) {
mapper := NewTypeMapper("")
col := &models.Column{
Name: "tags",
Type: "text[]",
NotNull: false,
}
tag := mapper.BuildGormTag(col, nil)
if !strings.Contains(tag, "type:text[]") {
t.Fatalf("expected array type to be preserved, got %q", tag)
}
}
func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) {
mapper := NewTypeMapper("")
+142 -22
View File
@@ -31,6 +31,10 @@ type MigrationWriter struct {
// NewMigrationWriter creates a new templated migration writer
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
if options == nil {
options = &writers.WriterOptions{}
}
executor, err := NewTemplateExecutor(options.FlattenSchema)
if err != nil {
return nil, fmt.Errorf("failed to create template executor: %w", err)
@@ -44,6 +48,16 @@ func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error
// WriteMigration generates migration scripts using templates
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) error {
if model == nil {
return fmt.Errorf("model database is required")
}
if w.options == nil {
w.options = &writers.WriterOptions{}
}
if current == nil {
current = models.InitDatabase(model.Name)
}
var writer io.Writer
var file *os.File
var err error
@@ -86,9 +100,16 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
// Process each schema in the model
for _, modelSchema := range model.Schemas {
if modelSchema == nil {
continue
}
// Find corresponding schema in current database
var currentSchema *models.Schema
for _, cs := range current.Schemas {
if cs == nil {
continue
}
if strings.EqualFold(cs.Name, modelSchema.Name) {
currentSchema = cs
break
@@ -123,7 +144,11 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
// Write header
fmt.Fprintf(w.writer, "-- PostgreSQL Migration Script\n")
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n")
fmt.Fprintf(w.writer, "-- Source: %s -> %s\n\n", current.Name, model.Name)
fmt.Fprintf(w.writer, "-- Source: %s -> %s\n", current.Name, model.Name)
if w.options.ContinueOnError {
fmt.Fprintf(w.writer, "\\set ON_ERROR_STOP off\n")
}
fmt.Fprintf(w.writer, "\n")
// Write scripts
for _, script := range scripts {
@@ -139,13 +164,26 @@ func (w *MigrationWriter) WriteMigration(model *models.Database, current *models
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0)
// Phase 1: Drop constraints and indexes that changed (Priority 11-50)
if schemaRequiresPGTrgm(model) {
scripts = append(scripts, MigrationScript{
ObjectName: "extension.pg_trgm",
ObjectType: "create extension",
Schema: model.Name,
Priority: 80,
Sequence: len(scripts),
Body: "CREATE EXTENSION IF NOT EXISTS pg_trgm;",
})
}
// Phase 1: Drop constraints and indexes that changed (Priority 5-50)
var droppedFKs map[string]bool
if current != nil {
dropScripts, err := w.generateDropScripts(model, current)
dropScripts, dropped, err := w.generateDropScripts(model, current)
if err != nil {
return nil, fmt.Errorf("failed to generate drop scripts: %w", err)
}
scripts = append(scripts, dropScripts...)
droppedFKs = dropped
}
// Phase 3: Create/Alter tables and columns (Priority 100-145)
@@ -163,7 +201,7 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m
scripts = append(scripts, indexScripts...)
// Phase 5: Create foreign keys (Priority 195)
fkScripts, err := w.generateForeignKeyScripts(model, current)
fkScripts, err := w.generateForeignKeyScripts(model, current, droppedFKs)
if err != nil {
return nil, fmt.Errorf("failed to generate foreign key scripts: %w", err)
}
@@ -179,9 +217,12 @@ func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *m
return scripts, nil
}
// generateDropScripts generates DROP scripts using templates
func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
// generateDropScripts generates DROP scripts using templates.
// Returns the scripts and a set of FK constraint keys (schema.table.name) that were
// explicitly dropped because their referenced PK was being dropped, so they can be force-recreated.
func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, map[string]bool, error) {
scripts := make([]MigrationScript, 0)
droppedFKs := make(map[string]bool)
// Build map of model tables for quick lookup
modelTables := make(map[string]*models.Table)
@@ -208,6 +249,44 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
shouldDrop = true
}
if shouldDrop && currentConstraint.Type == models.PrimaryKeyConstraint {
// Drop FK constraints that depend on this PK before dropping the PK itself.
for _, otherTable := range current.Tables {
for fkName, fkConstraint := range otherTable.Constraints {
if fkConstraint.Type != models.ForeignKeyConstraint {
continue
}
refTable := fkConstraint.ReferencedTable
refSchema := fkConstraint.ReferencedSchema
if refSchema == "" {
refSchema = current.Name
}
if strings.EqualFold(refTable, currentTable.Name) && strings.EqualFold(refSchema, current.Name) {
fkKey := fmt.Sprintf("%s.%s.%s", current.Name, otherTable.Name, fkName)
if !droppedFKs[fkKey] {
droppedFKs[fkKey] = true
sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{
SchemaName: current.Name,
TableName: otherTable.Name,
ConstraintName: fkName,
})
if err != nil {
return nil, nil, err
}
scripts = append(scripts, MigrationScript{
ObjectName: fkKey,
ObjectType: "drop constraint",
Schema: current.Name,
Priority: 5,
Sequence: len(scripts),
Body: sql,
})
}
}
}
}
}
if shouldDrop {
sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{
SchemaName: current.Name,
@@ -215,7 +294,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
ConstraintName: constraintName,
})
if err != nil {
return nil, err
return nil, nil, err
}
script := MigrationScript{
@@ -247,7 +326,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
IndexName: indexName,
})
if err != nil {
return nil, err
return nil, nil, err
}
script := MigrationScript{
@@ -263,7 +342,7 @@ func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *mod
}
}
return scripts, nil
return scripts, droppedFKs, nil
}
// generateTableScripts generates CREATE/ALTER TABLE scripts using templates
@@ -340,7 +419,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
SchemaName: schema.Name,
TableName: modelTable.Name,
ColumnName: modelCol.Name,
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
ColumnType: effectiveColumnSQLType(modelCol),
Default: defaultVal,
NotNull: modelCol.NotNull,
})
@@ -359,12 +438,13 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
scripts = append(scripts, script)
} else if !columnsEqual(modelCol, currentCol) {
// Column exists but properties changed
if modelCol.Type != currentCol.Type {
if !columnTypesEqual(modelCol, currentCol) {
sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
SchemaName: schema.Name,
TableName: modelTable.Name,
ColumnName: modelCol.Name,
NewType: pgsql.ConvertSQLType(modelCol.Type),
NewType: effectiveAlterColumnSQLType(modelCol),
UsingExpr: buildAlterColumnUsingExpression(modelCol.Name, effectiveAlterColumnSQLType(modelCol)),
})
if err != nil {
return nil, err
@@ -545,12 +625,17 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
indexType = modelIndex.Type
}
columnExprs := buildIndexColumnExpressions(modelTable, modelIndex, indexType)
if len(columnExprs) == 0 {
continue
}
sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
SchemaName: model.Name,
TableName: modelTable.Name,
IndexName: indexName,
IndexType: indexType,
Columns: strings.Join(modelIndex.Columns, ", "),
Columns: strings.Join(columnExprs, ", "),
Unique: modelIndex.Unique,
})
if err != nil {
@@ -573,8 +658,30 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
return scripts, nil
}
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
func buildIndexColumnExpressions(table *models.Table, index *models.Index, indexType string) []string {
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
if table != nil {
if col, ok := resolveIndexColumn(table, colName); ok && col != nil {
colExpr = col.SQLName()
if strings.EqualFold(indexType, "gin") {
opClass := ginOperatorClassForColumn(col, index.Comment)
if opClass != "" {
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
}
}
}
columnExprs = append(columnExprs, colExpr)
}
return columnExprs
}
// generateForeignKeyScripts generates ADD CONSTRAINT FOREIGN KEY scripts using templates.
// forceRecreate is a set of FK constraint keys (schema.table.name) that must be recreated
// even if unchanged, because their referenced PK was dropped and recreated.
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema, forceRecreate map[string]bool) ([]MigrationScript, error) {
scripts := make([]MigrationScript, 0)
// Build map of current tables
@@ -595,13 +702,16 @@ func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, curren
continue
}
shouldCreate := true
fkKey := fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName)
shouldCreate := forceRecreate[fkKey]
if currentTable != nil {
if currentConstraint, exists := currentTable.Constraints[constraintName]; exists {
if constraintsEqual(constraint, currentConstraint) {
shouldCreate = false
}
if !shouldCreate {
if currentTable == nil {
shouldCreate = true
} else if currentConstraint, exists := currentTable.Constraints[constraintName]; !exists {
shouldCreate = true
} else if !constraintsEqual(constraint, currentConstraint) {
shouldCreate = true
}
}
@@ -828,11 +938,21 @@ func columnsEqual(col1, col2 *models.Column) bool {
if col1 == nil || col2 == nil {
return false
}
return strings.EqualFold(col1.Type, col2.Type) &&
return columnTypesEqual(col1, col2) &&
col1.NotNull == col2.NotNull &&
fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
}
func columnTypesEqual(col1, col2 *models.Column) bool {
if col1 == nil || col2 == nil {
return false
}
return strings.EqualFold(
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col1)),
pgsql.NormalizeEquivalentSQLType(effectiveColumnSQLType(col2)),
)
}
// constraintsEqual checks if two constraints are equal
func constraintsEqual(c1, c2 *models.Constraint) bool {
if c1 == nil || c2 == nil {
+407
View File
@@ -97,6 +97,370 @@ func TestWriteMigration_ArrayDefault(t *testing.T) {
}
}
func TestWriteMigration_AltersColumnTypeWhenActualTypeDiffers(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("learnings", "public")
currentDetails := models.InitColumn("details", "learnings", "public")
currentDetails.Type = "jsonb"
currentTable.Columns["details"] = currentDetails
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("learnings", "public")
modelDetails := models.InitColumn("details", "learnings", "public")
modelDetails.Type = "text"
modelTable.Columns["details"] = modelDetails
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "ALTER TABLE public.learnings") || !strings.Contains(output, "ALTER COLUMN details TYPE text") {
t.Fatalf("expected migration to alter mismatched column type, got:\n%s", output)
}
if !strings.Contains(output, `ALTER COLUMN details TYPE text USING details::text;`) {
t.Fatalf("expected migration type alter to include USING cast, got:\n%s", output)
}
}
func TestWriteMigration_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("learnings", "public")
currentID := models.InitColumn("id", "learnings", "public")
currentID.Type = "uuid"
currentTable.Columns["id"] = currentID
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("learnings", "public")
modelID := models.InitColumn("id", "learnings", "public")
modelID.Type = "bigserial"
modelTable.Columns["id"] = modelID
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
}
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
}
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
}
}
func TestWriteMigration_ArrayAlterIncludesUsingCast(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("learnings", "public")
currentTags := models.InitColumn("tags", "learnings", "public")
currentTags.Type = "text"
currentTable.Columns["tags"] = currentTags
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("learnings", "public")
modelTags := models.InitColumn("tags", "learnings", "public")
modelTags.Type = "text[]"
modelTable.Columns["tags"] = modelTags
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
t.Fatalf("expected array alter to include USING cast, got:\n%s", output)
}
}
func TestWriteMigration_DoesNotAlterEquivalentNormalizedColumnType(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
currentTable := models.InitTable("users", "public")
currentEmail := models.InitColumn("email", "users", "public")
currentEmail.Type = "character varying"
currentEmail.Length = 255
currentTable.Columns["email"] = currentEmail
currentSchema.Tables = append(currentSchema.Tables, currentTable)
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
modelTable := models.InitTable("users", "public")
modelEmail := models.InitColumn("email", "users", "public")
modelEmail.Type = "varchar(255)"
modelTable.Columns["email"] = modelEmail
modelSchema.Tables = append(modelSchema.Tables, modelTable)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if strings.Contains(output, "ALTER COLUMN email TYPE") {
t.Fatalf("did not expect alter type for equivalent normalized types, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("articles", "public")
titleCol := models.InitColumn("title", "articles", "public")
titleCol.Type = "text"
table.Columns["title"] = titleCol
index := &models.Index{
Name: "idx_articles_title_gin",
Type: "gin",
Columns: []string{"title"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "CREATE EXTENSION IF NOT EXISTS pg_trgm;") {
t.Fatalf("expected trigram extension for text GIN migration index, got:\n%s", output)
}
if !strings.Contains(output, "USING gin (title gin_trgm_ops)") {
t.Fatalf("expected GIN text index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("agent_personas", "public")
nameCol := models.InitColumn("name", "agent_personas", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
index := &models.Index{
Name: "idx_agent_personas_name_gin",
Type: "gin",
Columns: []string{`"name"`},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (name gin_trgm_ops)") {
t.Fatalf("expected quoted text column GIN index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
index := &models.Index{
Name: "idx_plans_tags",
Type: "gin",
Columns: []string{"tags"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (tags array_ops)") {
t.Fatalf("expected GIN array index with array_ops, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] migration index, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
detailsCol := models.InitColumn("details", "learnings", "public")
detailsCol.Type = "jsonb"
table.Columns["details"] = detailsCol
index := &models.Index{
Name: "idx_learnings_details",
Type: "gin",
Columns: []string{"details"},
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for jsonb migration index, got:\n%s", output)
}
}
func TestWriteMigration_GinIndexOnJSONBIgnoresIncompatibleTrigramOperatorClass(t *testing.T) {
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("public")
current.Schemas = append(current.Schemas, currentSchema)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
detailsCol := models.InitColumn("details", "learnings", "public")
detailsCol.Type = "jsonb"
table.Columns["details"] = detailsCol
index := &models.Index{
Name: "idx_learnings_details",
Type: "gin",
Columns: []string{"details"},
Comment: "gin_trgm_ops",
}
table.Indexes[index.Name] = index
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, current); err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "USING gin (details jsonb_ops)") {
t.Fatalf("expected incompatible trigram hint on jsonb to fall back to jsonb_ops, got:\n%s", output)
}
}
func TestWriteMigration_WithAudit(t *testing.T) {
// Current database (empty)
current := models.InitDatabase("testdb")
@@ -322,3 +686,46 @@ func TestWriteMigration_NumericConstraintNames(t *testing.T) {
t.Error("Migration missing FOREIGN KEY")
}
}
func TestNewMigrationWriter_NilOptions(t *testing.T) {
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("NewMigrationWriter(nil) returned error: %v", err)
}
if writer == nil {
t.Fatal("expected writer instance")
}
if writer.options == nil {
t.Fatal("expected default writer options to be initialized")
}
}
func TestWriteMigration_NilCurrentTreatsDatabaseAsEmpty(t *testing.T) {
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
table.Columns["id"] = idCol
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
var buf bytes.Buffer
writer, err := NewMigrationWriter(nil)
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
if err := writer.WriteMigration(model, nil); err != nil {
t.Fatalf("WriteMigration with nil current failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "CREATE TABLE") {
t.Fatalf("expected CREATE TABLE in migration output, got:\n%s", output)
}
}
+5 -15
View File
@@ -5,6 +5,8 @@ import (
"regexp"
"strings"
"unicode"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
)
// TemplateFunctions returns a map of custom template functions
@@ -162,24 +164,12 @@ func quoteIdent(s string) string {
// Type conversion functions
// goTypeToSQL converts Go type to PostgreSQL type
// goTypeToSQL converts Go type to PostgreSQL type using the shared pgsql type map.
func goTypeToSQL(goType string) string {
typeMap := map[string]string{
"string": "text",
"int": "integer",
"int32": "integer",
"int64": "bigint",
"float32": "real",
"float64": "double precision",
"bool": "boolean",
"time.Time": "timestamp",
"[]byte": "bytea",
}
if sqlType, ok := typeMap[goType]; ok {
if sqlType, ok := pgsql.GoToPGSQLTypes[goType]; ok {
return sqlType
}
return "text" // Default
return "text"
}
// sqlTypeToGo converts PostgreSQL type to Go type
+20
View File
@@ -95,6 +95,16 @@ type AlterColumnTypeData struct {
TableName string
ColumnName string
NewType string
UsingExpr string
}
type AlterColumnTypeWithCheckData struct {
SchemaName string
TableName string
ColumnName string
NewType string
EquivalentTypes string
UsingExpr string
}
// AlterColumnDefaultData contains data for alter column default template
@@ -267,6 +277,7 @@ type CreatePrimaryKeyWithAutoGenCheckData struct {
ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string
ColumnNames string // Comma-separated list of quoted column names like "'id', 'tenant_id'"
}
// Execute methods for each template
@@ -301,6 +312,15 @@ func (te *TemplateExecutor) ExecuteAlterColumnType(data AlterColumnTypeData) (st
return buf.String(), nil
}
func (te *TemplateExecutor) ExecuteAlterColumnTypeWithCheck(data AlterColumnTypeWithCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "alter_column_type_with_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute alter_column_type_with_check template: %w", err)
}
return buf.String(), nil
}
// ExecuteAlterColumnDefault executes the alter column default template
func (te *TemplateExecutor) ExecuteAlterColumnDefault(data AlterColumnDefaultData) (string, error) {
var buf bytes.Buffer
@@ -1,2 +1,2 @@
ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
@@ -0,0 +1,22 @@
DO $$
DECLARE
current_type text;
BEGIN
SELECT pg_catalog.format_type(a.atttypid, a.atttypmod)
INTO current_type
FROM pg_attribute a
JOIN pg_class t ON t.oid = a.attrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
WHERE n.nspname = '{{.SchemaName}}'
AND t.relname = '{{.TableName}}'
AND a.attname = '{{.ColumnName}}'
AND a.attnum > 0
AND NOT a.attisdropped;
IF current_type IS NOT NULL
AND current_type <> ALL(ARRAY[{{.EquivalentTypes}}]) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}}{{if .UsingExpr}} USING {{.UsingExpr}}{{end}};
END IF;
END;
$$;
@@ -1,26 +1,42 @@
DO $$
DECLARE
auto_pk_name text;
current_pk_name text;
current_pk_matches boolean := false;
BEGIN
-- Drop auto-generated primary key if it exists
SELECT constraint_name INTO auto_pk_name
FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_type = 'PRIMARY KEY'
AND constraint_name IN ({{.AutoGenNames}});
SELECT tc.constraint_name,
COALESCE(
ARRAY(
SELECT a.attname::text
FROM pg_constraint c
JOIN pg_class t ON t.oid = c.conrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN unnest(c.conkey) WITH ORDINALITY AS cols(attnum, ord)
ON TRUE
JOIN pg_attribute a
ON a.attrelid = t.oid
AND a.attnum = cols.attnum
WHERE c.contype = 'p'
AND n.nspname = '{{.SchemaName}}'
AND t.relname = '{{.TableName}}'
ORDER BY cols.ord
),
ARRAY[]::text[]
) = ARRAY[{{.ColumnNames}}]
INTO current_pk_name, current_pk_matches
FROM information_schema.table_constraints tc
WHERE tc.table_schema = '{{.SchemaName}}'
AND tc.table_name = '{{.TableName}}'
AND tc.constraint_type = 'PRIMARY KEY';
IF auto_pk_name IS NOT NULL THEN
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
IF current_pk_name IS NOT NULL
AND NOT current_pk_matches
AND current_pk_name IN ({{.AutoGenNames}}) THEN
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(current_pk_name) || ' CASCADE';
END IF;
-- Add named primary key if it doesn't exist
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
-- Add the desired primary key only when no matching primary key already exists.
IF current_pk_name IS NULL
OR (NOT current_pk_matches AND current_pk_name IN ({{.AutoGenNames}})) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
+259 -48
View File
@@ -99,7 +99,11 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
// Write header comment
fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n")
fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n")
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n")
if w.options.ContinueOnError {
fmt.Fprintf(w.writer, "\\set ON_ERROR_STOP off\n")
}
fmt.Fprintf(w.writer, "\n")
// Process each schema in the database
for _, schema := range db.Schemas {
@@ -143,6 +147,10 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
}
if schemaRequiresPGTrgm(schema) {
statements = append(statements, `CREATE EXTENSION IF NOT EXISTS pg_trgm`)
}
// Phase 2: Create sequences
for _, table := range schema.Tables {
pk := table.GetPrimaryKey()
@@ -181,6 +189,12 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
}
statements = append(statements, addColStmts...)
alterTypeStmts, err := w.GenerateAlterColumnTypeStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate alter column type statements: %w", err)
}
statements = append(statements, alterTypeStmts...)
// Phase 4: Primary keys
for _, table := range schema.Tables {
// First check for explicit PrimaryKeyConstraint
@@ -228,6 +242,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "),
ColumnNames: formatStringList(pkColumns),
}
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -260,16 +275,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
colExpr := colName
if col, ok := table.Columns[colName]; ok {
// For GIN indexes on text columns, add operator class
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
}
if col, ok := resolveIndexColumn(table, colName); ok {
if strings.EqualFold(indexType, "gin") {
if opClass := ginOperatorClassForColumn(col, index.Comment); opClass != "" {
colExpr = fmt.Sprintf("%s %s", colName, opClass)
}
}
}
columnExprs = append(columnExprs, colExpr)
}
@@ -436,6 +448,33 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
return statements, nil
}
func (w *Writer) GenerateAlterColumnTypeStatements(schema *models.Schema) ([]string, error) {
statements := []string{}
statements = append(statements, fmt.Sprintf("-- Alter column types for schema: %s", schema.Name))
for _, table := range schema.Tables {
columns := getSortedColumns(table.Columns)
for _, col := range columns {
targetType := effectiveAlterColumnSQLType(col)
stmt, err := w.executor.ExecuteAlterColumnTypeWithCheck(AlterColumnTypeWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ColumnName: col.Name,
NewType: targetType,
EquivalentTypes: equivalentTypeListSQL(targetType),
UsingExpr: buildAlterColumnUsingExpression(col.Name, targetType),
})
if err != nil {
return nil, fmt.Errorf("failed to generate alter column type for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
}
statements = append(statements, stmt)
}
}
return statements, nil
}
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
statements := []string{}
@@ -488,31 +527,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
func (w *Writer) generateColumnDefinition(col *models.Column) string {
parts := []string{col.SQLName()}
// Type with length/precision - convert to valid PostgreSQL type
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
// Only add size specifiers for types that support them
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
if pgsql.SupportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
// Convert text with length to varchar
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
// For types that don't support length (integer, bigint, etc.), ignore the length
} else if !hasExplicitTypeModifier && col.Precision > 0 {
if pgsql.SupportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
// For types that don't support precision, ignore it
}
parts = append(parts, typeStr)
parts = append(parts, effectiveColumnSQLType(col))
// NOT NULL
if col.NotNull {
@@ -534,6 +549,64 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
return strings.Join(parts, " ")
}
func effectiveColumnSQLType(col *models.Column) string {
if col == nil {
return ""
}
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType)
if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 {
if pgsql.SupportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
} else if !hasExplicitTypeModifier && col.Precision > 0 {
if pgsql.SupportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
}
return typeStr
}
func effectiveAlterColumnSQLType(col *models.Column) string {
typeStr := effectiveColumnSQLType(col)
switch strings.ToLower(strings.TrimSpace(typeStr)) {
case "smallserial":
return "smallint"
case "serial":
return "integer"
case "bigserial":
return "bigint"
default:
return typeStr
}
}
func buildAlterColumnUsingExpression(columnName, targetType string) string {
if strings.TrimSpace(columnName) == "" || strings.TrimSpace(targetType) == "" {
return ""
}
return fmt.Sprintf("%s::%s", quoteIdent(columnName), targetType)
}
func equivalentTypeListSQL(sqlType string) string {
variants := pgsql.EquivalentSQLTypeVariants(sqlType)
quoted := make([]string, 0, len(variants))
for _, variant := range variants {
quoted = append(quoted, fmt.Sprintf("'%s'", escapeQuote(variant)))
}
return strings.Join(quoted, ", ")
}
// WriteSchema writes a single schema and all its tables
func (w *Writer) WriteSchema(schema *models.Schema) error {
if w.writer == nil {
@@ -545,6 +618,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
if err := w.writeRequiredExtensions(schema); err != nil {
return err
}
// Phase 2: Create sequences (priority 80)
if err := w.writeSequences(schema); err != nil {
return err
@@ -560,6 +637,10 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
if err := w.writeAlterColumnTypes(schema); err != nil {
return err
}
// Phase 4: Create primary keys (priority 160)
if err := w.writePrimaryKeys(schema); err != nil {
return err
@@ -660,6 +741,16 @@ func (w *Writer) writeCreateSchema(schema *models.Schema) error {
return nil
}
func (w *Writer) writeRequiredExtensions(schema *models.Schema) error {
if !schemaRequiresPGTrgm(schema) {
return nil
}
fmt.Fprintln(w.writer, "CREATE EXTENSION IF NOT EXISTS pg_trgm;")
fmt.Fprintln(w.writer)
return nil
}
// writeSequences generates CREATE SEQUENCE statements for identity columns
func (w *Writer) writeSequences(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
@@ -753,6 +844,21 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
return nil
}
func (w *Writer) writeAlterColumnTypes(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Alter column types for schema: %s\n", schema.Name)
statements, err := w.GenerateAlterColumnTypeStatements(schema)
if err != nil {
return err
}
for _, stmt := range statements[1:] {
fmt.Fprint(w.writer, stmt)
fmt.Fprint(w.writer, "\n")
}
return nil
}
// writePrimaryKeys generates ALTER TABLE statements for primary keys
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
@@ -806,6 +912,7 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "),
ColumnNames: formatStringList(columnNames),
}
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
@@ -853,16 +960,14 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
// Build column list with operator class support for GIN indexes
columnExprs := make([]string, 0, len(index.Columns))
for _, colName := range index.Columns {
if col, ok := table.Columns[colName]; ok {
if col, ok := resolveIndexColumn(table, colName); ok {
colExpr := col.SQLName()
// For GIN indexes on text columns, add operator class
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
opClass := extractOperatorClass(index.Comment)
if opClass == "" {
opClass = "gin_trgm_ops"
}
if strings.EqualFold(index.Type, "gin") {
opClass := ginOperatorClassForColumn(col, index.Comment)
if opClass != "" {
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
}
}
columnExprs = append(columnExprs, colExpr)
}
}
@@ -1248,20 +1353,126 @@ func isIntegerType(colType string) bool {
}
// isTextType checks if a column type is a text type (for GIN index operator class)
func isTextType(colType string) bool {
textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
lowerType := strings.ToLower(colType)
for _, t := range textTypes {
if strings.HasPrefix(lowerType, t) {
// func isTextType(colType string) bool {
// textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
// lowerType := strings.ToLower(colType)
// if strings.HasSuffix(lowerType, "[]") {
// return false
// }
// for _, t := range textTypes {
// if strings.HasPrefix(lowerType, t) {
// return true
// }
// }
// return false
// }
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
func isTextTypeWithoutLength(colType string) bool {
return strings.EqualFold(colType, "text")
}
func ginOperatorClassForColumn(col *models.Column, comment string) string {
if col == nil {
return ""
}
sqlType := effectiveColumnSQLType(col)
baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType))
isArray := pgsql.IsArrayType(sqlType)
requested := extractOperatorClass(comment)
if requested != "" && ginOperatorClassCompatible(baseType, isArray, requested) {
return requested
}
if isArray {
return "array_ops"
}
switch {
case isTextGinBaseType(baseType):
return "gin_trgm_ops"
case baseType == "jsonb":
return "jsonb_ops"
default:
return requested
}
}
func ginOperatorClassCompatible(baseType string, isArray bool, opClass string) bool {
switch opClass {
case "gin_trgm_ops", "gin_bigm_ops":
return !isArray && isTextGinBaseType(baseType)
case "jsonb_ops", "jsonb_path_ops":
return !isArray && baseType == "jsonb"
case "array_ops":
return isArray
default:
return true
}
}
func isTextGinBaseType(baseType string) bool {
switch baseType {
case "text", "varchar", "character varying", "char", "character", "string", "citext", "bpchar":
return true
default:
return false
}
}
func schemaRequiresPGTrgm(schema *models.Schema) bool {
if schema == nil {
return false
}
for _, table := range schema.Tables {
if table == nil {
continue
}
for _, index := range table.Indexes {
if index == nil || !strings.EqualFold(index.Type, "gin") {
continue
}
for _, colName := range index.Columns {
col, ok := resolveIndexColumn(table, colName)
if !ok || col == nil {
continue
}
if ginOperatorClassForColumn(col, index.Comment) == "gin_trgm_ops" {
return true
}
}
}
}
return false
}
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
func isTextTypeWithoutLength(colType string) bool {
return strings.EqualFold(colType, "text")
func resolveIndexColumn(table *models.Table, colName string) (*models.Column, bool) {
if table == nil {
return nil, false
}
if col, ok := table.Columns[colName]; ok && col != nil {
return col, true
}
normalized := strings.ToLower(strings.Trim(colName, `"`))
for key, col := range table.Columns {
if col == nil {
continue
}
if strings.ToLower(strings.Trim(key, `"`)) == normalized {
return col, true
}
if strings.ToLower(strings.Trim(col.Name, `"`)) == normalized {
return col, true
}
if strings.ToLower(strings.Trim(col.SQLName(), `"`)) == normalized {
return col, true
}
}
return nil, false
}
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
+235 -3
View File
@@ -87,6 +87,117 @@ func TestWriteDatabase(t *testing.T) {
}
}
func TestWriteDatabase_GinIndexOnTextArrayDoesNotUseTrigramOperatorClass(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("plans", "public")
tagsCol := models.InitColumn("tags", "plans", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
index := &models.Index{
Name: "idx_plans_tags",
Type: "gin",
Columns: []string{"tags"},
}
table.Indexes[index.Name] = index
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `USING gin (tags array_ops)`) {
t.Fatalf("expected GIN index on array column with array_ops, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for text[] column, got:\n%s", output)
}
}
func TestWriteDatabase_GinIndexOnQuotedTextColumnUsesTrigramOperatorClass(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("agent_personas", "public")
nameCol := models.InitColumn("name", "agent_personas", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
index := &models.Index{
Name: "idx_agent_personas_name_gin",
Type: "gin",
Columns: []string{`"name"`},
}
table.Indexes[index.Name] = index
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `CREATE EXTENSION IF NOT EXISTS pg_trgm`) {
t.Fatalf("expected trigram extension for text GIN index, got:\n%s", output)
}
if !strings.Contains(output, `USING gin (name gin_trgm_ops)`) {
t.Fatalf("expected quoted text GIN index to include gin_trgm_ops, got:\n%s", output)
}
}
func TestWriteDatabase_GinIndexOnJSONBUsesJSONBOperatorClass(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
detailsCol := models.InitColumn("details", "learnings", "public")
detailsCol.Type = "jsonb"
table.Columns["details"] = detailsCol
index := &models.Index{
Name: "idx_learnings_details",
Type: "gin",
Columns: []string{"details"},
}
table.Indexes[index.Name] = index
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, `USING gin (details jsonb_ops)`) {
t.Fatalf("expected GIN jsonb index to include jsonb_ops, got:\n%s", output)
}
if strings.Contains(output, "gin_trgm_ops") {
t.Fatalf("did not expect gin_trgm_ops for jsonb column, got:\n%s", output)
}
}
func TestWriteForeignKeys(t *testing.T) {
// Create a test database with two related tables
db := models.InitDatabase("testdb")
@@ -636,9 +747,14 @@ func TestPrimaryKeyExistenceCheck(t *testing.T) {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
}
// Verify it checks for our specific named constraint before adding it
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
// Verify it compares the current primary key columns before dropping/recreating
if !strings.Contains(output, "current_pk_matches") || !strings.Contains(output, "ARRAY['id']") {
t.Errorf("Output missing safe primary key comparison logic\nFull output:\n%s", output)
}
// Verify it only adds the desired key when no PK exists or an auto-generated mismatch was dropped
if !strings.Contains(output, "current_pk_name IS NULL") || !strings.Contains(output, "current_pk_name IN ('products_pkey', 'public_products_pkey')") {
t.Errorf("Output missing guarded primary key creation logic\nFull output:\n%s", output)
}
}
@@ -729,6 +845,43 @@ func TestColumnSizeSpecifiers(t *testing.T) {
}
}
func TestWriteDatabase_PrimaryKeyTemplateDoesNotDropMatchingAutoPrimaryKey(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
idCol := models.InitColumn("id", "learnings", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
parentCol := models.InitColumn("duplicate_of_learning_id", "learnings", "public")
parentCol.Type = "bigint"
table.Columns["duplicate_of_learning_id"] = parentCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "current_pk_matches") {
t.Fatalf("expected generated SQL to compare current PK columns, got:\n%s", output)
}
if !strings.Contains(output, "ARRAY['id']") {
t.Fatalf("expected generated SQL to compare against desired PK columns, got:\n%s", output)
}
if !strings.Contains(output, "NOT current_pk_matches") {
t.Fatalf("expected generated SQL to avoid dropping matching PKs, got:\n%s", output)
}
}
func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) {
writer := NewWriter(&writers.WriterOptions{})
@@ -905,3 +1058,82 @@ func TestWriteAddColumnStatements(t *testing.T) {
t.Errorf("Output missing DO block\nFull output:\n%s", output)
}
}
func TestWriteSchema_EmitsGuardedAlterColumnTypeStatements(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("agent_skills", "public")
nameCol := models.InitColumn("name", "agent_skills", "public")
nameCol.Type = "character varying"
nameCol.Length = 255
table.Columns["name"] = nameCol
tagsCol := models.InitColumn("tags", "agent_skills", "public")
tagsCol.Type = "text[]"
table.Columns["tags"] = tagsCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "-- Alter column types for schema: public") {
t.Fatalf("expected alter column type section, got:\n%s", output)
}
if !strings.Contains(output, "pg_catalog.format_type") {
t.Fatalf("expected guarded live-type check, got:\n%s", output)
}
if !strings.Contains(output, "ALTER COLUMN name TYPE character varying(255)") {
t.Fatalf("expected guarded alter for character varying(255), got:\n%s", output)
}
if !strings.Contains(output, "ARRAY['varchar(255)', 'character varying(255)']") {
t.Fatalf("expected equivalent type spellings for varchar guard, got:\n%s", output)
}
if !strings.Contains(output, "ALTER COLUMN tags TYPE text[]") {
t.Fatalf("expected guarded alter for array type, got:\n%s", output)
}
if !strings.Contains(output, `ALTER COLUMN tags TYPE text[] USING tags::text[];`) {
t.Fatalf("expected guarded alter for array type to include USING cast, got:\n%s", output)
}
}
func TestWriteSchema_UsesStorageTypeForSerialAlterStatements(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("learnings", "public")
idCol := models.InitColumn("id", "learnings", "public")
idCol.Type = "bigserial"
table.Columns["id"] = idCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
writer := NewWriter(&writers.WriterOptions{})
writer.writer = &buf
if err := writer.WriteDatabase(db); err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
if !strings.Contains(output, "ALTER COLUMN id TYPE bigint") {
t.Fatalf("expected serial alter to use bigint storage type, got:\n%s", output)
}
if strings.Contains(output, "ALTER COLUMN id TYPE bigserial;") {
t.Fatalf("did not expect invalid bigserial alter statement, got:\n%s", output)
}
if !strings.Contains(output, `ALTER COLUMN id TYPE bigint USING id::bigint;`) {
t.Fatalf("expected serial alter to include USING cast, got:\n%s", output)
}
}
+17 -60
View File
@@ -2,9 +2,11 @@ package sqlite
import (
"strings"
sqlitepkg "git.warky.dev/wdevs/relspecgo/pkg/sqlite"
)
// SQLite type affinities
// SQLite type affinity constants
const (
TypeText = "TEXT"
TypeInteger = "INTEGER"
@@ -13,72 +15,27 @@ const (
TypeBlob = "BLOB"
)
// MapPostgreSQLType maps PostgreSQL data types to SQLite type affinities
// MapTypeToSQLite maps any SQL or Go canonical type to a SQLite type affinity.
// Handles input from any reader (PostgreSQL, MSSQL, SQLite, Go canonical).
func MapTypeToSQLite(colType string) string {
return sqlitepkg.ConvertCanonicalToSQLite(colType)
}
// MapPostgreSQLType is an alias for MapTypeToSQLite kept for compatibility.
//
// Deprecated: use MapTypeToSQLite.
func MapPostgreSQLType(pgType string) string {
// Normalize the type
normalized := strings.ToLower(strings.TrimSpace(pgType))
// Remove array notation if present
normalized = strings.TrimSuffix(normalized, "[]")
// Remove precision/scale if present
if idx := strings.Index(normalized, "("); idx != -1 {
normalized = normalized[:idx]
return MapTypeToSQLite(pgType)
}
// Map to SQLite type affinity
switch normalized {
// TEXT affinity
case "varchar", "character varying", "text", "char", "character",
"citext", "uuid", "timestamp", "timestamptz", "timestamp with time zone",
"timestamp without time zone", "date", "time", "timetz", "time with time zone",
"time without time zone", "json", "jsonb", "xml", "inet", "cidr", "macaddr":
return TypeText
// INTEGER affinity
case "int", "int2", "int4", "int8", "integer", "smallint", "bigint",
"serial", "smallserial", "bigserial", "boolean", "bool":
return TypeInteger
// REAL affinity
case "real", "float", "float4", "float8", "double precision":
return TypeReal
// NUMERIC affinity
case "numeric", "decimal", "money":
return TypeNumeric
// BLOB affinity
case "bytea", "blob":
return TypeBlob
default:
// Default to TEXT for unknown types
return TypeText
}
}
// IsIntegerType checks if a column type should be treated as integer
// IsIntegerType reports whether a column type maps to SQLite INTEGER affinity.
func IsIntegerType(colType string) bool {
normalized := strings.ToLower(strings.TrimSpace(colType))
normalized = strings.TrimSuffix(normalized, "[]")
if idx := strings.Index(normalized, "("); idx != -1 {
normalized = normalized[:idx]
return MapTypeToSQLite(colType) == TypeInteger
}
switch normalized {
case "int", "int2", "int4", "int8", "integer", "smallint", "bigint",
"serial", "smallserial", "bigserial":
return true
default:
return false
}
}
// MapBooleanValue converts PostgreSQL boolean literals to SQLite (0/1)
// MapBooleanValue converts common boolean literals to SQLite integers (1/0).
func MapBooleanValue(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
switch strings.ToLower(strings.TrimSpace(value)) {
case "true", "t", "yes", "y", "1":
return "1"
case "false", "f", "no", "n", "0":
+6 -1
View File
@@ -54,6 +54,10 @@ type WriterOptions struct {
// Prisma7 enables Prisma 7-specific output for Prisma writers.
Prisma7 bool
// ContinueOnError instructs SQL writers to prepend `\set ON_ERROR_STOP off`
// to their output so that psql continues past errors instead of stopping.
ContinueOnError bool
// Additional options can be added here as needed
Metadata map[string]interface{}
}
@@ -207,7 +211,8 @@ func quoteSQLLiteral(value string) string {
// - Returns a clean identifier safe for use in struct tags and field names
func SanitizeStructTagValue(value string) string {
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
// Require at least one character inside brackets to avoid stripping PostgreSQL array suffix "[]"
commentRegex := regexp.MustCompile(`\s*\[[^\]]+\]\s*`)
value = commentRegex.ReplaceAllString(value, "")
// Trim whitespace