Compare commits
9 Commits
v1.0.33
...
v1.0.38-2-
| Author | SHA1 | Date | |
|---|---|---|---|
| dc9172cc7c | |||
| ee88c07989 | |||
| ff1180524a | |||
|
|
480038d51d | ||
| 77436757c8 | |||
| 5e6f03e412 | |||
| 1dcbc79387 | |||
| 59c4a5ebf8 | |||
| 091e1913ee |
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/merge"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
|
"git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
|
||||||
@@ -45,6 +46,7 @@ var (
|
|||||||
convertSourceType string
|
convertSourceType string
|
||||||
convertSourcePath string
|
convertSourcePath string
|
||||||
convertSourceConn string
|
convertSourceConn string
|
||||||
|
convertFromList []string
|
||||||
convertTargetType string
|
convertTargetType string
|
||||||
convertTargetPath string
|
convertTargetPath string
|
||||||
convertPackageName string
|
convertPackageName string
|
||||||
@@ -166,6 +168,7 @@ func init() {
|
|||||||
convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql, sqlite)")
|
convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql, sqlite)")
|
||||||
convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)")
|
convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)")
|
||||||
convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for pgsql) or file path (for sqlite)")
|
convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for pgsql) or file path (for sqlite)")
|
||||||
|
convertCmd.Flags().StringSliceVar(&convertFromList, "from-list", nil, "Comma-separated list of source file paths to read and merge (mutually exclusive with --from-path)")
|
||||||
|
|
||||||
convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||||
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
|
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
|
||||||
@@ -191,17 +194,29 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, "\n=== RelSpec Schema Converter ===\n")
|
fmt.Fprintf(os.Stderr, "\n=== RelSpec Schema Converter ===\n")
|
||||||
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
|
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
|
||||||
|
|
||||||
|
// Validate mutually exclusive flags
|
||||||
|
if convertSourcePath != "" && len(convertFromList) > 0 {
|
||||||
|
return fmt.Errorf("--from-path and --from-list are mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
// Read source database
|
// Read source database
|
||||||
fmt.Fprintf(os.Stderr, "[1/2] Reading source schema...\n")
|
fmt.Fprintf(os.Stderr, "[1/2] Reading source schema...\n")
|
||||||
fmt.Fprintf(os.Stderr, " Format: %s\n", convertSourceType)
|
fmt.Fprintf(os.Stderr, " Format: %s\n", convertSourceType)
|
||||||
if convertSourcePath != "" {
|
|
||||||
fmt.Fprintf(os.Stderr, " Path: %s\n", convertSourcePath)
|
|
||||||
}
|
|
||||||
if convertSourceConn != "" {
|
|
||||||
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(convertSourceConn))
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := readDatabaseForConvert(convertSourceType, convertSourcePath, convertSourceConn)
|
var db *models.Database
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(convertFromList) > 0 {
|
||||||
|
db, err = readDatabaseListForConvert(convertSourceType, convertFromList)
|
||||||
|
} else {
|
||||||
|
if convertSourcePath != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " Path: %s\n", convertSourcePath)
|
||||||
|
}
|
||||||
|
if convertSourceConn != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(convertSourceConn))
|
||||||
|
}
|
||||||
|
db, err = readDatabaseForConvert(convertSourceType, convertSourcePath, convertSourceConn)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read source: %w", err)
|
return fmt.Errorf("failed to read source: %w", err)
|
||||||
}
|
}
|
||||||
@@ -237,6 +252,30 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readDatabaseListForConvert(dbType string, files []string) (*models.Database, error) {
|
||||||
|
if len(files) == 0 {
|
||||||
|
return nil, fmt.Errorf("file list is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, " Files: %d file(s)\n", len(files))
|
||||||
|
|
||||||
|
var base *models.Database
|
||||||
|
for i, filePath := range files {
|
||||||
|
fmt.Fprintf(os.Stderr, " [%d/%d] %s\n", i+1, len(files), filePath)
|
||||||
|
db, err := readDatabaseForConvert(dbType, filePath, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read %s: %w", filePath, err)
|
||||||
|
}
|
||||||
|
if base == nil {
|
||||||
|
base = db
|
||||||
|
} else {
|
||||||
|
merge.MergeDatabases(base, db, &merge.MergeOptions{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return base, nil
|
||||||
|
}
|
||||||
|
|
||||||
func readDatabaseForConvert(dbType, filePath, connString string) (*models.Database, error) {
|
func readDatabaseForConvert(dbType, filePath, connString string) (*models.Database, error) {
|
||||||
var reader readers.Reader
|
var reader readers.Reader
|
||||||
|
|
||||||
|
|||||||
183
cmd/relspec/convert_from_list_test.go
Normal file
183
cmd/relspec/convert_from_list_test.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReadDatabaseListForConvert_SingleFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
file := filepath.Join(dir, "schema.json")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
|
||||||
|
db, err := readDatabaseListForConvert("json", []string{file})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(db.Schemas) == 0 {
|
||||||
|
t.Fatal("expected at least one schema")
|
||||||
|
}
|
||||||
|
if len(db.Schemas[0].Tables) != 1 {
|
||||||
|
t.Errorf("expected 1 table, got %d", len(db.Schemas[0].Tables))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabaseListForConvert_MultipleFiles(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
file1 := filepath.Join(dir, "schema1.json")
|
||||||
|
file2 := filepath.Join(dir, "schema2.json")
|
||||||
|
writeTestJSON(t, file1, []string{"users"})
|
||||||
|
writeTestJSON(t, file2, []string{"comments"})
|
||||||
|
|
||||||
|
db, err := readDatabaseListForConvert("json", []string{file1, file2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, s := range db.Schemas {
|
||||||
|
total += len(s.Tables)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("expected 2 tables (users + comments), got %d", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabaseListForConvert_PathWithSpaces(t *testing.T) {
|
||||||
|
spacedDir := filepath.Join(t.TempDir(), "my schema files")
|
||||||
|
if err := os.MkdirAll(spacedDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
file := filepath.Join(spacedDir, "my users schema.json")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
|
||||||
|
db, err := readDatabaseListForConvert("json", []string{file})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error with spaced path: %v", err)
|
||||||
|
}
|
||||||
|
if db == nil {
|
||||||
|
t.Fatal("expected non-nil database")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabaseListForConvert_MultipleFilesPathWithSpaces(t *testing.T) {
|
||||||
|
spacedDir := filepath.Join(t.TempDir(), "my schema files")
|
||||||
|
if err := os.MkdirAll(spacedDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
file1 := filepath.Join(spacedDir, "users schema.json")
|
||||||
|
file2 := filepath.Join(spacedDir, "posts schema.json")
|
||||||
|
writeTestJSON(t, file1, []string{"users"})
|
||||||
|
writeTestJSON(t, file2, []string{"posts"})
|
||||||
|
|
||||||
|
db, err := readDatabaseListForConvert("json", []string{file1, file2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error with spaced paths: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, s := range db.Schemas {
|
||||||
|
total += len(s.Tables)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("expected 2 tables, got %d", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabaseListForConvert_EmptyList(t *testing.T) {
|
||||||
|
_, err := readDatabaseListForConvert("json", []string{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty file list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabaseListForConvert_InvalidFile(t *testing.T) {
|
||||||
|
_, err := readDatabaseListForConvert("json", []string{"/nonexistent/path/file.json"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunConvert_FromListMutuallyExclusiveWithFromPath(t *testing.T) {
|
||||||
|
saved := saveConvertState()
|
||||||
|
defer restoreConvertState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file := filepath.Join(dir, "schema.json")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
|
||||||
|
convertSourceType = "json"
|
||||||
|
convertSourcePath = file
|
||||||
|
convertFromList = []string{file}
|
||||||
|
convertTargetType = "json"
|
||||||
|
convertTargetPath = filepath.Join(dir, "out.json")
|
||||||
|
|
||||||
|
err := runConvert(nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when --from-path and --from-list are both set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunConvert_FromListEndToEnd(t *testing.T) {
|
||||||
|
saved := saveConvertState()
|
||||||
|
defer restoreConvertState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file1 := filepath.Join(dir, "users.json")
|
||||||
|
file2 := filepath.Join(dir, "posts.json")
|
||||||
|
outFile := filepath.Join(dir, "merged.json")
|
||||||
|
writeTestJSON(t, file1, []string{"users"})
|
||||||
|
writeTestJSON(t, file2, []string{"posts"})
|
||||||
|
|
||||||
|
convertSourceType = "json"
|
||||||
|
convertSourcePath = ""
|
||||||
|
convertSourceConn = ""
|
||||||
|
convertFromList = []string{file1, file2}
|
||||||
|
convertTargetType = "json"
|
||||||
|
convertTargetPath = outFile
|
||||||
|
convertPackageName = ""
|
||||||
|
convertSchemaFilter = ""
|
||||||
|
convertFlattenSchema = false
|
||||||
|
|
||||||
|
if err := runConvert(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runConvert() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunConvert_FromListEndToEndPathWithSpaces(t *testing.T) {
|
||||||
|
saved := saveConvertState()
|
||||||
|
defer restoreConvertState(saved)
|
||||||
|
|
||||||
|
spacedDir := filepath.Join(t.TempDir(), "my schema dir")
|
||||||
|
if err := os.MkdirAll(spacedDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
file1 := filepath.Join(spacedDir, "users schema.json")
|
||||||
|
file2 := filepath.Join(spacedDir, "posts schema.json")
|
||||||
|
outFile := filepath.Join(spacedDir, "merged output.json")
|
||||||
|
writeTestJSON(t, file1, []string{"users"})
|
||||||
|
writeTestJSON(t, file2, []string{"posts"})
|
||||||
|
|
||||||
|
convertSourceType = "json"
|
||||||
|
convertSourcePath = ""
|
||||||
|
convertSourceConn = ""
|
||||||
|
convertFromList = []string{file1, file2}
|
||||||
|
convertTargetType = "json"
|
||||||
|
convertTargetPath = outFile
|
||||||
|
convertPackageName = ""
|
||||||
|
convertSchemaFilter = ""
|
||||||
|
convertFlattenSchema = false
|
||||||
|
|
||||||
|
if err := runConvert(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runConvert() with spaced paths error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -47,6 +47,7 @@ var (
|
|||||||
mergeSourceType string
|
mergeSourceType string
|
||||||
mergeSourcePath string
|
mergeSourcePath string
|
||||||
mergeSourceConn string
|
mergeSourceConn string
|
||||||
|
mergeFromList []string
|
||||||
mergeOutputType string
|
mergeOutputType string
|
||||||
mergeOutputPath string
|
mergeOutputPath string
|
||||||
mergeOutputConn string
|
mergeOutputConn string
|
||||||
@@ -109,8 +110,9 @@ func init() {
|
|||||||
|
|
||||||
// Source database flags
|
// Source database flags
|
||||||
mergeCmd.Flags().StringVar(&mergeSourceType, "source", "", "Source format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
mergeCmd.Flags().StringVar(&mergeSourceType, "source", "", "Source format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
||||||
mergeCmd.Flags().StringVar(&mergeSourcePath, "source-path", "", "Source file path (required for file-based formats)")
|
mergeCmd.Flags().StringVar(&mergeSourcePath, "source-path", "", "Source file path (required for file-based formats, mutually exclusive with --from-list)")
|
||||||
mergeCmd.Flags().StringVar(&mergeSourceConn, "source-conn", "", "Source connection string (required for pgsql)")
|
mergeCmd.Flags().StringVar(&mergeSourceConn, "source-conn", "", "Source connection string (required for pgsql)")
|
||||||
|
mergeCmd.Flags().StringSliceVar(&mergeFromList, "from-list", nil, "Comma-separated list of source file paths to merge (mutually exclusive with --source-path)")
|
||||||
|
|
||||||
// Output flags
|
// Output flags
|
||||||
mergeCmd.Flags().StringVar(&mergeOutputType, "output", "", "Output format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
mergeCmd.Flags().StringVar(&mergeOutputType, "output", "", "Output format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
||||||
@@ -144,6 +146,11 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("--output format is required")
|
return fmt.Errorf("--output format is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate mutually exclusive source flags
|
||||||
|
if mergeSourcePath != "" && len(mergeFromList) > 0 {
|
||||||
|
return fmt.Errorf("--source-path and --from-list are mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
// Validate and expand file paths
|
// Validate and expand file paths
|
||||||
if mergeTargetType != "pgsql" {
|
if mergeTargetType != "pgsql" {
|
||||||
if mergeTargetPath == "" {
|
if mergeTargetPath == "" {
|
||||||
@@ -157,8 +164,8 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mergeSourceType != "pgsql" {
|
if mergeSourceType != "pgsql" {
|
||||||
if mergeSourcePath == "" {
|
if mergeSourcePath == "" && len(mergeFromList) == 0 {
|
||||||
return fmt.Errorf("--source-path is required for %s format", mergeSourceType)
|
return fmt.Errorf("--source-path or --from-list is required for %s format", mergeSourceType)
|
||||||
}
|
}
|
||||||
mergeSourcePath = expandPath(mergeSourcePath)
|
mergeSourcePath = expandPath(mergeSourcePath)
|
||||||
} else if mergeSourceConn == "" {
|
} else if mergeSourceConn == "" {
|
||||||
@@ -189,19 +196,36 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " ✓ Successfully read target database '%s'\n", targetDB.Name)
|
fmt.Fprintf(os.Stderr, " ✓ Successfully read target database '%s'\n", targetDB.Name)
|
||||||
printDatabaseStats(targetDB)
|
printDatabaseStats(targetDB)
|
||||||
|
|
||||||
// Step 2: Read source database
|
// Step 2: Read source database(s)
|
||||||
fmt.Fprintf(os.Stderr, "\n[2/3] Reading source database...\n")
|
fmt.Fprintf(os.Stderr, "\n[2/3] Reading source database...\n")
|
||||||
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeSourceType)
|
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeSourceType)
|
||||||
if mergeSourcePath != "" {
|
|
||||||
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeSourcePath)
|
|
||||||
}
|
|
||||||
if mergeSourceConn != "" {
|
|
||||||
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(mergeSourceConn))
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceDB, err := readDatabaseForMerge(mergeSourceType, mergeSourcePath, mergeSourceConn, "Source")
|
var sourceDB *models.Database
|
||||||
if err != nil {
|
if len(mergeFromList) > 0 {
|
||||||
return fmt.Errorf("failed to read source database: %w", err)
|
fmt.Fprintf(os.Stderr, " Files: %d file(s)\n", len(mergeFromList))
|
||||||
|
for i, filePath := range mergeFromList {
|
||||||
|
fmt.Fprintf(os.Stderr, " [%d/%d] %s\n", i+1, len(mergeFromList), filePath)
|
||||||
|
db, readErr := readDatabaseForMerge(mergeSourceType, expandPath(filePath), "", "Source")
|
||||||
|
if readErr != nil {
|
||||||
|
return fmt.Errorf("failed to read source file %s: %w", filePath, readErr)
|
||||||
|
}
|
||||||
|
if sourceDB == nil {
|
||||||
|
sourceDB = db
|
||||||
|
} else {
|
||||||
|
merge.MergeDatabases(sourceDB, db, &merge.MergeOptions{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if mergeSourcePath != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeSourcePath)
|
||||||
|
}
|
||||||
|
if mergeSourceConn != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(mergeSourceConn))
|
||||||
|
}
|
||||||
|
sourceDB, err = readDatabaseForMerge(mergeSourceType, mergeSourcePath, mergeSourceConn, "Source")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read source database: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, " ✓ Successfully read source database '%s'\n", sourceDB.Name)
|
fmt.Fprintf(os.Stderr, " ✓ Successfully read source database '%s'\n", sourceDB.Name)
|
||||||
printDatabaseStats(sourceDB)
|
printDatabaseStats(sourceDB)
|
||||||
|
|||||||
162
cmd/relspec/merge_from_list_test.go
Normal file
162
cmd/relspec/merge_from_list_test.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRunMerge_FromListMutuallyExclusiveWithSourcePath(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file := filepath.Join(dir, "schema.json")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = file
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = file
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = []string{file}
|
||||||
|
mergeOutputType = "json"
|
||||||
|
mergeOutputPath = filepath.Join(dir, "out.json")
|
||||||
|
mergeOutputConn = ""
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
err := runMerge(nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when --source-path and --from-list are both set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMerge_FromListSingleFile(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
targetFile := filepath.Join(dir, "target.json")
|
||||||
|
sourceFile := filepath.Join(dir, "source.json")
|
||||||
|
outFile := filepath.Join(dir, "output.json")
|
||||||
|
writeTestJSON(t, targetFile, []string{"users"})
|
||||||
|
writeTestJSON(t, sourceFile, []string{"posts"})
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = targetFile
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = ""
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = []string{sourceFile}
|
||||||
|
mergeOutputType = "json"
|
||||||
|
mergeOutputPath = outFile
|
||||||
|
mergeOutputConn = ""
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
if err := runMerge(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runMerge() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMerge_FromListMultipleFiles(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
targetFile := filepath.Join(dir, "target.json")
|
||||||
|
source1 := filepath.Join(dir, "source1.json")
|
||||||
|
source2 := filepath.Join(dir, "source2.json")
|
||||||
|
outFile := filepath.Join(dir, "output.json")
|
||||||
|
writeTestJSON(t, targetFile, []string{"users"})
|
||||||
|
writeTestJSON(t, source1, []string{"posts"})
|
||||||
|
writeTestJSON(t, source2, []string{"comments"})
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = targetFile
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = ""
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = []string{source1, source2}
|
||||||
|
mergeOutputType = "json"
|
||||||
|
mergeOutputPath = outFile
|
||||||
|
mergeOutputConn = ""
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
if err := runMerge(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runMerge() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMerge_FromListPathWithSpaces(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
spacedDir := filepath.Join(t.TempDir(), "my schema files")
|
||||||
|
if err := os.MkdirAll(spacedDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
targetFile := filepath.Join(spacedDir, "target schema.json")
|
||||||
|
sourceFile := filepath.Join(spacedDir, "source schema.json")
|
||||||
|
outFile := filepath.Join(spacedDir, "merged output.json")
|
||||||
|
writeTestJSON(t, targetFile, []string{"users"})
|
||||||
|
writeTestJSON(t, sourceFile, []string{"comments"})
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = targetFile
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = ""
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = []string{sourceFile}
|
||||||
|
mergeOutputType = "json"
|
||||||
|
mergeOutputPath = outFile
|
||||||
|
mergeOutputConn = ""
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
if err := runMerge(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runMerge() with spaced paths error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMerge_FromListMissingSourceType(t *testing.T) {
|
||||||
|
saved := saveMergeState()
|
||||||
|
defer restoreMergeState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file := filepath.Join(dir, "schema.json")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
|
||||||
|
mergeTargetType = "json"
|
||||||
|
mergeTargetPath = file
|
||||||
|
mergeTargetConn = ""
|
||||||
|
mergeSourceType = "json"
|
||||||
|
mergeSourcePath = ""
|
||||||
|
mergeSourceConn = ""
|
||||||
|
mergeFromList = []string{} // empty list, no source-path either
|
||||||
|
mergeOutputType = "json"
|
||||||
|
mergeOutputPath = filepath.Join(dir, "out.json")
|
||||||
|
mergeOutputConn = ""
|
||||||
|
mergeSkipTables = ""
|
||||||
|
mergeReportPath = ""
|
||||||
|
|
||||||
|
err := runMerge(nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when neither --source-path nor --from-list is provided")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@@ -12,6 +14,36 @@ var (
|
|||||||
buildDate = "unknown"
|
buildDate = "unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// If version wasn't set via ldflags, try to get it from build info
|
||||||
|
if version == "dev" {
|
||||||
|
if info, ok := debug.ReadBuildInfo(); ok {
|
||||||
|
// Try to get version from VCS
|
||||||
|
var vcsRevision, vcsTime string
|
||||||
|
for _, setting := range info.Settings {
|
||||||
|
switch setting.Key {
|
||||||
|
case "vcs.revision":
|
||||||
|
if len(setting.Value) >= 7 {
|
||||||
|
vcsRevision = setting.Value[:7]
|
||||||
|
}
|
||||||
|
case "vcs.time":
|
||||||
|
vcsTime = setting.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if vcsRevision != "" {
|
||||||
|
version = vcsRevision
|
||||||
|
}
|
||||||
|
|
||||||
|
if vcsTime != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, vcsTime); err == nil {
|
||||||
|
buildDate = t.UTC().Format("2006-01-02 15:04:05 UTC")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
Use: "relspec",
|
Use: "relspec",
|
||||||
Short: "RelSpec - Database schema conversion and analysis tool",
|
Short: "RelSpec - Database schema conversion and analysis tool",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ var (
|
|||||||
templSourceType string
|
templSourceType string
|
||||||
templSourcePath string
|
templSourcePath string
|
||||||
templSourceConn string
|
templSourceConn string
|
||||||
|
templFromList []string
|
||||||
templTemplatePath string
|
templTemplatePath string
|
||||||
templOutputPath string
|
templOutputPath string
|
||||||
templSchemaFilter string
|
templSchemaFilter string
|
||||||
@@ -78,8 +79,9 @@ Examples:
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
templCmd.Flags().StringVar(&templSourceType, "from", "", "Source format (dbml, pgsql, json, etc.)")
|
templCmd.Flags().StringVar(&templSourceType, "from", "", "Source format (dbml, pgsql, json, etc.)")
|
||||||
templCmd.Flags().StringVar(&templSourcePath, "from-path", "", "Source file path (for file-based sources)")
|
templCmd.Flags().StringVar(&templSourcePath, "from-path", "", "Source file path (for file-based sources, mutually exclusive with --from-list)")
|
||||||
templCmd.Flags().StringVar(&templSourceConn, "from-conn", "", "Source connection string (for database sources)")
|
templCmd.Flags().StringVar(&templSourceConn, "from-conn", "", "Source connection string (for database sources)")
|
||||||
|
templCmd.Flags().StringSliceVar(&templFromList, "from-list", nil, "Comma-separated list of source file paths to read and merge (mutually exclusive with --from-path)")
|
||||||
templCmd.Flags().StringVar(&templTemplatePath, "template", "", "Template file path (required)")
|
templCmd.Flags().StringVar(&templTemplatePath, "template", "", "Template file path (required)")
|
||||||
templCmd.Flags().StringVar(&templOutputPath, "output", "", "Output path (file or directory, empty for stdout)")
|
templCmd.Flags().StringVar(&templOutputPath, "output", "", "Output path (file or directory, empty for stdout)")
|
||||||
templCmd.Flags().StringVar(&templSchemaFilter, "schema", "", "Filter to specific schema")
|
templCmd.Flags().StringVar(&templSchemaFilter, "schema", "", "Filter to specific schema")
|
||||||
@@ -95,9 +97,20 @@ func runTempl(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, "=== RelSpec Template Execution ===\n")
|
fmt.Fprintf(os.Stderr, "=== RelSpec Template Execution ===\n")
|
||||||
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
|
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
|
||||||
|
|
||||||
|
// Validate mutually exclusive flags
|
||||||
|
if templSourcePath != "" && len(templFromList) > 0 {
|
||||||
|
return fmt.Errorf("--from-path and --from-list are mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
// Read database using the same function as convert
|
// Read database using the same function as convert
|
||||||
fmt.Fprintf(os.Stderr, "Reading from %s...\n", templSourceType)
|
fmt.Fprintf(os.Stderr, "Reading from %s...\n", templSourceType)
|
||||||
db, err := readDatabaseForConvert(templSourceType, templSourcePath, templSourceConn)
|
var db *models.Database
|
||||||
|
var err error
|
||||||
|
if len(templFromList) > 0 {
|
||||||
|
db, err = readDatabaseListForConvert(templSourceType, templFromList)
|
||||||
|
} else {
|
||||||
|
db, err = readDatabaseForConvert(templSourceType, templSourcePath, templSourceConn)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read source: %w", err)
|
return fmt.Errorf("failed to read source: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
134
cmd/relspec/templ_from_list_test.go
Normal file
134
cmd/relspec/templ_from_list_test.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeTestTemplate writes a minimal Go text template file.
|
||||||
|
func writeTestTemplate(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
content := []byte(`{{.Name}}`)
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write template file %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunTempl_FromListMutuallyExclusiveWithFromPath(t *testing.T) {
|
||||||
|
saved := saveTemplState()
|
||||||
|
defer restoreTemplState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file := filepath.Join(dir, "schema.json")
|
||||||
|
tmpl := filepath.Join(dir, "tmpl.tmpl")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
writeTestTemplate(t, tmpl)
|
||||||
|
|
||||||
|
templSourceType = "json"
|
||||||
|
templSourcePath = file
|
||||||
|
templFromList = []string{file}
|
||||||
|
templTemplatePath = tmpl
|
||||||
|
templOutputPath = ""
|
||||||
|
templMode = "database"
|
||||||
|
templFilenamePattern = "{{.Name}}.txt"
|
||||||
|
|
||||||
|
err := runTempl(nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when --from-path and --from-list are both set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunTempl_FromListSingleFile(t *testing.T) {
|
||||||
|
saved := saveTemplState()
|
||||||
|
defer restoreTemplState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file := filepath.Join(dir, "schema.json")
|
||||||
|
tmpl := filepath.Join(dir, "tmpl.tmpl")
|
||||||
|
outFile := filepath.Join(dir, "output.txt")
|
||||||
|
writeTestJSON(t, file, []string{"users"})
|
||||||
|
writeTestTemplate(t, tmpl)
|
||||||
|
|
||||||
|
templSourceType = "json"
|
||||||
|
templSourcePath = ""
|
||||||
|
templSourceConn = ""
|
||||||
|
templFromList = []string{file}
|
||||||
|
templTemplatePath = tmpl
|
||||||
|
templOutputPath = outFile
|
||||||
|
templSchemaFilter = ""
|
||||||
|
templMode = "database"
|
||||||
|
templFilenamePattern = "{{.Name}}.txt"
|
||||||
|
|
||||||
|
if err := runTempl(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runTempl() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunTempl_FromListMultipleFiles(t *testing.T) {
|
||||||
|
saved := saveTemplState()
|
||||||
|
defer restoreTemplState(saved)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
file1 := filepath.Join(dir, "users.json")
|
||||||
|
file2 := filepath.Join(dir, "posts.json")
|
||||||
|
tmpl := filepath.Join(dir, "tmpl.tmpl")
|
||||||
|
outFile := filepath.Join(dir, "output.txt")
|
||||||
|
writeTestJSON(t, file1, []string{"users"})
|
||||||
|
writeTestJSON(t, file2, []string{"posts"})
|
||||||
|
writeTestTemplate(t, tmpl)
|
||||||
|
|
||||||
|
templSourceType = "json"
|
||||||
|
templSourcePath = ""
|
||||||
|
templSourceConn = ""
|
||||||
|
templFromList = []string{file1, file2}
|
||||||
|
templTemplatePath = tmpl
|
||||||
|
templOutputPath = outFile
|
||||||
|
templSchemaFilter = ""
|
||||||
|
templMode = "database"
|
||||||
|
templFilenamePattern = "{{.Name}}.txt"
|
||||||
|
|
||||||
|
if err := runTempl(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runTempl() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunTempl_FromListPathWithSpaces(t *testing.T) {
|
||||||
|
saved := saveTemplState()
|
||||||
|
defer restoreTemplState(saved)
|
||||||
|
|
||||||
|
spacedDir := filepath.Join(t.TempDir(), "my schema files")
|
||||||
|
if err := os.MkdirAll(spacedDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
file1 := filepath.Join(spacedDir, "users schema.json")
|
||||||
|
file2 := filepath.Join(spacedDir, "posts schema.json")
|
||||||
|
tmpl := filepath.Join(spacedDir, "my template.tmpl")
|
||||||
|
outFile := filepath.Join(spacedDir, "output file.txt")
|
||||||
|
writeTestJSON(t, file1, []string{"users"})
|
||||||
|
writeTestJSON(t, file2, []string{"posts"})
|
||||||
|
writeTestTemplate(t, tmpl)
|
||||||
|
|
||||||
|
templSourceType = "json"
|
||||||
|
templSourcePath = ""
|
||||||
|
templSourceConn = ""
|
||||||
|
templFromList = []string{file1, file2}
|
||||||
|
templTemplatePath = tmpl
|
||||||
|
templOutputPath = outFile
|
||||||
|
templSchemaFilter = ""
|
||||||
|
templMode = "database"
|
||||||
|
templFilenamePattern = "{{.Name}}.txt"
|
||||||
|
|
||||||
|
if err := runTempl(nil, nil); err != nil {
|
||||||
|
t.Fatalf("runTempl() with spaced paths error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(outFile); os.IsNotExist(err) {
|
||||||
|
t.Error("expected output file to be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
219
cmd/relspec/testhelpers_test.go
Normal file
219
cmd/relspec/testhelpers_test.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// minimalColumn is used to build test JSON fixtures.
|
||||||
|
type minimalColumn struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Table string `json:"table"`
|
||||||
|
Schema string `json:"schema"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
NotNull bool `json:"not_null"`
|
||||||
|
IsPrimaryKey bool `json:"is_primary_key"`
|
||||||
|
AutoIncrement bool `json:"auto_increment"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type minimalTable struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Schema string `json:"schema"`
|
||||||
|
Columns map[string]minimalColumn `json:"columns"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type minimalSchema struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Tables []minimalTable `json:"tables"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type minimalDatabase struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Schemas []minimalSchema `json:"schemas"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTestJSON writes a minimal JSON database file with one schema ("public")
|
||||||
|
// containing tables with the given names. Each table has a single "id" PK column.
|
||||||
|
func writeTestJSON(t *testing.T, path string, tableNames []string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tables := make([]minimalTable, len(tableNames))
|
||||||
|
for i, name := range tableNames {
|
||||||
|
tables[i] = minimalTable{
|
||||||
|
Name: name,
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]minimalColumn{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Table: name,
|
||||||
|
Schema: "public",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
db := minimalDatabase{
|
||||||
|
Name: "test_db",
|
||||||
|
Schemas: []minimalSchema{{Name: "public", Tables: tables}},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal test JSON: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertState captures and restores all convert global vars.
|
||||||
|
type convertState struct {
|
||||||
|
sourceType string
|
||||||
|
sourcePath string
|
||||||
|
sourceConn string
|
||||||
|
fromList []string
|
||||||
|
targetType string
|
||||||
|
targetPath string
|
||||||
|
packageName string
|
||||||
|
schemaFilter string
|
||||||
|
flattenSchema bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveConvertState() convertState {
|
||||||
|
return convertState{
|
||||||
|
sourceType: convertSourceType,
|
||||||
|
sourcePath: convertSourcePath,
|
||||||
|
sourceConn: convertSourceConn,
|
||||||
|
fromList: convertFromList,
|
||||||
|
targetType: convertTargetType,
|
||||||
|
targetPath: convertTargetPath,
|
||||||
|
packageName: convertPackageName,
|
||||||
|
schemaFilter: convertSchemaFilter,
|
||||||
|
flattenSchema: convertFlattenSchema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func restoreConvertState(s convertState) {
|
||||||
|
convertSourceType = s.sourceType
|
||||||
|
convertSourcePath = s.sourcePath
|
||||||
|
convertSourceConn = s.sourceConn
|
||||||
|
convertFromList = s.fromList
|
||||||
|
convertTargetType = s.targetType
|
||||||
|
convertTargetPath = s.targetPath
|
||||||
|
convertPackageName = s.packageName
|
||||||
|
convertSchemaFilter = s.schemaFilter
|
||||||
|
convertFlattenSchema = s.flattenSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// templState captures and restores all templ global vars.
|
||||||
|
type templState struct {
|
||||||
|
sourceType string
|
||||||
|
sourcePath string
|
||||||
|
sourceConn string
|
||||||
|
fromList []string
|
||||||
|
templatePath string
|
||||||
|
outputPath string
|
||||||
|
schemaFilter string
|
||||||
|
mode string
|
||||||
|
filenamePattern string
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveTemplState() templState {
|
||||||
|
return templState{
|
||||||
|
sourceType: templSourceType,
|
||||||
|
sourcePath: templSourcePath,
|
||||||
|
sourceConn: templSourceConn,
|
||||||
|
fromList: templFromList,
|
||||||
|
templatePath: templTemplatePath,
|
||||||
|
outputPath: templOutputPath,
|
||||||
|
schemaFilter: templSchemaFilter,
|
||||||
|
mode: templMode,
|
||||||
|
filenamePattern: templFilenamePattern,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func restoreTemplState(s templState) {
|
||||||
|
templSourceType = s.sourceType
|
||||||
|
templSourcePath = s.sourcePath
|
||||||
|
templSourceConn = s.sourceConn
|
||||||
|
templFromList = s.fromList
|
||||||
|
templTemplatePath = s.templatePath
|
||||||
|
templOutputPath = s.outputPath
|
||||||
|
templSchemaFilter = s.schemaFilter
|
||||||
|
templMode = s.mode
|
||||||
|
templFilenamePattern = s.filenamePattern
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeState captures and restores all merge global vars.
|
||||||
|
type mergeState struct {
|
||||||
|
targetType string
|
||||||
|
targetPath string
|
||||||
|
targetConn string
|
||||||
|
sourceType string
|
||||||
|
sourcePath string
|
||||||
|
sourceConn string
|
||||||
|
fromList []string
|
||||||
|
outputType string
|
||||||
|
outputPath string
|
||||||
|
outputConn string
|
||||||
|
skipDomains bool
|
||||||
|
skipRelations bool
|
||||||
|
skipEnums bool
|
||||||
|
skipViews bool
|
||||||
|
skipSequences bool
|
||||||
|
skipTables string
|
||||||
|
verbose bool
|
||||||
|
reportPath string
|
||||||
|
flattenSchema bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveMergeState() mergeState {
|
||||||
|
return mergeState{
|
||||||
|
targetType: mergeTargetType,
|
||||||
|
targetPath: mergeTargetPath,
|
||||||
|
targetConn: mergeTargetConn,
|
||||||
|
sourceType: mergeSourceType,
|
||||||
|
sourcePath: mergeSourcePath,
|
||||||
|
sourceConn: mergeSourceConn,
|
||||||
|
fromList: mergeFromList,
|
||||||
|
outputType: mergeOutputType,
|
||||||
|
outputPath: mergeOutputPath,
|
||||||
|
outputConn: mergeOutputConn,
|
||||||
|
skipDomains: mergeSkipDomains,
|
||||||
|
skipRelations: mergeSkipRelations,
|
||||||
|
skipEnums: mergeSkipEnums,
|
||||||
|
skipViews: mergeSkipViews,
|
||||||
|
skipSequences: mergeSkipSequences,
|
||||||
|
skipTables: mergeSkipTables,
|
||||||
|
verbose: mergeVerbose,
|
||||||
|
reportPath: mergeReportPath,
|
||||||
|
flattenSchema: mergeFlattenSchema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func restoreMergeState(s mergeState) {
|
||||||
|
mergeTargetType = s.targetType
|
||||||
|
mergeTargetPath = s.targetPath
|
||||||
|
mergeTargetConn = s.targetConn
|
||||||
|
mergeSourceType = s.sourceType
|
||||||
|
mergeSourcePath = s.sourcePath
|
||||||
|
mergeSourceConn = s.sourceConn
|
||||||
|
mergeFromList = s.fromList
|
||||||
|
mergeOutputType = s.outputType
|
||||||
|
mergeOutputPath = s.outputPath
|
||||||
|
mergeOutputConn = s.outputConn
|
||||||
|
mergeSkipDomains = s.skipDomains
|
||||||
|
mergeSkipRelations = s.skipRelations
|
||||||
|
mergeSkipEnums = s.skipEnums
|
||||||
|
mergeSkipViews = s.skipViews
|
||||||
|
mergeSkipSequences = s.skipSequences
|
||||||
|
mergeSkipTables = s.skipTables
|
||||||
|
mergeVerbose = s.verbose
|
||||||
|
mergeReportPath = s.reportPath
|
||||||
|
mergeFlattenSchema = s.flattenSchema
|
||||||
|
}
|
||||||
@@ -60,19 +60,19 @@ func (f *MarkdownFormatter) Format(report *InspectorReport) (string, error) {
|
|||||||
// Summary
|
// Summary
|
||||||
sb.WriteString(f.formatHeader("Summary"))
|
sb.WriteString(f.formatHeader("Summary"))
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
sb.WriteString(fmt.Sprintf("- Rules Checked: %d\n", report.Summary.RulesChecked))
|
fmt.Fprintf(&sb, "- Rules Checked: %d\n", report.Summary.RulesChecked)
|
||||||
|
|
||||||
// Color-code error and warning counts
|
// Color-code error and warning counts
|
||||||
if report.Summary.ErrorCount > 0 {
|
if report.Summary.ErrorCount > 0 {
|
||||||
sb.WriteString(f.colorize(fmt.Sprintf("- Errors: %d\n", report.Summary.ErrorCount), colorRed))
|
sb.WriteString(f.colorize(fmt.Sprintf("- Errors: %d\n", report.Summary.ErrorCount), colorRed))
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString(fmt.Sprintf("- Errors: %d\n", report.Summary.ErrorCount))
|
fmt.Fprintf(&sb, "- Errors: %d\n", report.Summary.ErrorCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
if report.Summary.WarningCount > 0 {
|
if report.Summary.WarningCount > 0 {
|
||||||
sb.WriteString(f.colorize(fmt.Sprintf("- Warnings: %d\n", report.Summary.WarningCount), colorYellow))
|
sb.WriteString(f.colorize(fmt.Sprintf("- Warnings: %d\n", report.Summary.WarningCount), colorYellow))
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString(fmt.Sprintf("- Warnings: %d\n", report.Summary.WarningCount))
|
fmt.Fprintf(&sb, "- Warnings: %d\n", report.Summary.WarningCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
if report.Summary.PassedCount > 0 {
|
if report.Summary.PassedCount > 0 {
|
||||||
|
|||||||
@@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
|||||||
}
|
}
|
||||||
|
|
||||||
column := models.InitColumn(columnName, tableName, schema)
|
column := models.InitColumn(columnName, tableName, schema)
|
||||||
column.Type = r.mapDataType(dataType, udtName)
|
|
||||||
column.NotNull = (isNullable == "NO")
|
|
||||||
column.Sequence = uint(ordinalPosition)
|
|
||||||
|
|
||||||
|
// Check if this is a serial type (has nextval default)
|
||||||
|
hasNextval := false
|
||||||
if columnDefault != nil {
|
if columnDefault != nil {
|
||||||
// Parse default value - remove nextval for sequences
|
|
||||||
defaultVal := *columnDefault
|
defaultVal := *columnDefault
|
||||||
if strings.HasPrefix(defaultVal, "nextval") {
|
if strings.HasPrefix(defaultVal, "nextval") {
|
||||||
|
hasNextval = true
|
||||||
column.AutoIncrement = true
|
column.AutoIncrement = true
|
||||||
column.Default = defaultVal
|
column.Default = defaultVal
|
||||||
} else {
|
} else {
|
||||||
@@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Map data type, preserving serial types when detected
|
||||||
|
column.Type = r.mapDataType(dataType, udtName, hasNextval)
|
||||||
|
column.NotNull = (isNullable == "NO")
|
||||||
|
column.Sequence = uint(ordinalPosition)
|
||||||
|
|
||||||
if description != nil {
|
if description != nil {
|
||||||
column.Description = *description
|
column.Description = *description
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package pgsql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
|
||||||
@@ -259,33 +260,46 @@ func (r *Reader) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// mapDataType maps PostgreSQL data types to canonical types
|
// mapDataType maps PostgreSQL data types to canonical types
|
||||||
func (r *Reader) mapDataType(pgType, udtName string) string {
|
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||||
|
// If the column has a nextval default, it's likely a serial type
|
||||||
|
// Map to the appropriate serial type instead of the base integer type
|
||||||
|
if hasNextval {
|
||||||
|
switch strings.ToLower(pgType) {
|
||||||
|
case "integer", "int", "int4":
|
||||||
|
return "serial"
|
||||||
|
case "bigint", "int8":
|
||||||
|
return "bigserial"
|
||||||
|
case "smallint", "int2":
|
||||||
|
return "smallserial"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Map common PostgreSQL types
|
// Map common PostgreSQL types
|
||||||
typeMap := map[string]string{
|
typeMap := map[string]string{
|
||||||
"integer": "int",
|
"integer": "integer",
|
||||||
"bigint": "int64",
|
"bigint": "bigint",
|
||||||
"smallint": "int16",
|
"smallint": "smallint",
|
||||||
"int": "int",
|
"int": "integer",
|
||||||
"int2": "int16",
|
"int2": "smallint",
|
||||||
"int4": "int",
|
"int4": "integer",
|
||||||
"int8": "int64",
|
"int8": "bigint",
|
||||||
"serial": "int",
|
"serial": "serial",
|
||||||
"bigserial": "int64",
|
"bigserial": "bigserial",
|
||||||
"smallserial": "int16",
|
"smallserial": "smallserial",
|
||||||
"numeric": "decimal",
|
"numeric": "numeric",
|
||||||
"decimal": "decimal",
|
"decimal": "decimal",
|
||||||
"real": "float32",
|
"real": "real",
|
||||||
"double precision": "float64",
|
"double precision": "double precision",
|
||||||
"float4": "float32",
|
"float4": "real",
|
||||||
"float8": "float64",
|
"float8": "double precision",
|
||||||
"money": "decimal",
|
"money": "money",
|
||||||
"character varying": "string",
|
"character varying": "varchar",
|
||||||
"varchar": "string",
|
"varchar": "varchar",
|
||||||
"character": "string",
|
"character": "char",
|
||||||
"char": "string",
|
"char": "char",
|
||||||
"text": "string",
|
"text": "text",
|
||||||
"boolean": "bool",
|
"boolean": "boolean",
|
||||||
"bool": "bool",
|
"bool": "boolean",
|
||||||
"date": "date",
|
"date": "date",
|
||||||
"time": "time",
|
"time": "time",
|
||||||
"time without time zone": "time",
|
"time without time zone": "time",
|
||||||
|
|||||||
@@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) {
|
|||||||
udtName string
|
udtName string
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{"integer", "int4", "int"},
|
{"integer", "int4", "integer"},
|
||||||
{"bigint", "int8", "int64"},
|
{"bigint", "int8", "bigint"},
|
||||||
{"smallint", "int2", "int16"},
|
{"smallint", "int2", "smallint"},
|
||||||
{"character varying", "varchar", "string"},
|
{"character varying", "varchar", "varchar"},
|
||||||
{"text", "text", "string"},
|
{"text", "text", "text"},
|
||||||
{"boolean", "bool", "bool"},
|
{"boolean", "bool", "boolean"},
|
||||||
{"timestamp without time zone", "timestamp", "timestamp"},
|
{"timestamp without time zone", "timestamp", "timestamp"},
|
||||||
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
||||||
{"json", "json", "json"},
|
{"json", "json", "json"},
|
||||||
{"jsonb", "jsonb", "jsonb"},
|
{"jsonb", "jsonb", "jsonb"},
|
||||||
{"uuid", "uuid", "uuid"},
|
{"uuid", "uuid", "uuid"},
|
||||||
{"numeric", "numeric", "decimal"},
|
{"numeric", "numeric", "numeric"},
|
||||||
{"real", "float4", "float32"},
|
{"real", "float4", "real"},
|
||||||
{"double precision", "float8", "float64"},
|
{"double precision", "float8", "double precision"},
|
||||||
{"date", "date", "date"},
|
{"date", "date", "date"},
|
||||||
{"time without time zone", "time", "time"},
|
{"time without time zone", "time", "time"},
|
||||||
{"bytea", "bytea", "bytea"},
|
{"bytea", "bytea", "bytea"},
|
||||||
@@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.pgType, func(t *testing.T) {
|
t.Run(tt.pgType, func(t *testing.T) {
|
||||||
result := reader.mapDataType(tt.pgType, tt.udtName)
|
result := reader.mapDataType(tt.pgType, tt.udtName, false)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test serial type detection with hasNextval=true
|
||||||
|
serialTests := []struct {
|
||||||
|
pgType string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"integer", "serial"},
|
||||||
|
{"bigint", "bigserial"},
|
||||||
|
{"smallint", "smallserial"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range serialTests {
|
||||||
|
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
|
||||||
|
result := reader.mapDataType(tt.pgType, "", true)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseIndexDefinition(t *testing.T) {
|
func TestParseIndexDefinition(t *testing.T) {
|
||||||
|
|||||||
@@ -62,6 +62,17 @@ func (tm *TypeMapper) isSimpleType(sqlType string) bool {
|
|||||||
return simpleTypes[sqlType]
|
return simpleTypes[sqlType]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isSerialType checks if a SQL type is a serial type (auto-incrementing)
|
||||||
|
func (tm *TypeMapper) isSerialType(sqlType string) bool {
|
||||||
|
baseType := tm.extractBaseType(sqlType)
|
||||||
|
serialTypes := map[string]bool{
|
||||||
|
"serial": true,
|
||||||
|
"bigserial": true,
|
||||||
|
"smallserial": true,
|
||||||
|
}
|
||||||
|
return serialTypes[baseType]
|
||||||
|
}
|
||||||
|
|
||||||
// baseGoType returns the base Go type for a SQL type (not null, simple types only)
|
// baseGoType returns the base Go type for a SQL type (not null, simple types only)
|
||||||
func (tm *TypeMapper) baseGoType(sqlType string) string {
|
func (tm *TypeMapper) baseGoType(sqlType string) string {
|
||||||
typeMap := map[string]string{
|
typeMap := map[string]string{
|
||||||
@@ -122,10 +133,10 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
|
|||||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||||
|
|
||||||
// Date/Time types
|
// Date/Time types
|
||||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
"timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
"timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||||
@@ -190,10 +201,15 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
parts = append(parts, "pk")
|
parts = append(parts, "pk")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auto increment (for serial types or explicit auto_increment)
|
||||||
|
if column.AutoIncrement || tm.isSerialType(column.Type) {
|
||||||
|
parts = append(parts, "autoincrement")
|
||||||
|
}
|
||||||
|
|
||||||
// Default value
|
// Default value
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
// Sanitize default value to remove backticks
|
// Sanitize default value to remove backticks, then quote based on column type
|
||||||
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
safeDefault := writers.QuoteDefaultValue(writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default)), column.Type)
|
||||||
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +267,15 @@ func (tm *TypeMapper) BuildRelationshipTag(constraint *models.Constraint, relTyp
|
|||||||
if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
|
if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
|
||||||
localCol := constraint.Columns[0]
|
localCol := constraint.Columns[0]
|
||||||
foreignCol := constraint.ReferencedColumns[0]
|
foreignCol := constraint.ReferencedColumns[0]
|
||||||
parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
|
|
||||||
|
// For has-many relationships, swap the columns
|
||||||
|
// has-one: join:fk_in_this_table=pk_in_other_table
|
||||||
|
// has-many: join:pk_in_this_table=fk_in_other_table
|
||||||
|
if relType == "has-many" {
|
||||||
|
parts = append(parts, fmt.Sprintf("join:%s=%s", foreignCol, localCol))
|
||||||
|
} else {
|
||||||
|
parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(parts, ",")
|
return strings.Join(parts, ",")
|
||||||
|
|||||||
@@ -90,8 +90,8 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify Bun-specific elements
|
// Verify Bun-specific elements
|
||||||
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,") {
|
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,autoincrement,") {
|
||||||
t.Errorf("Missing Bun-style primary key tag")
|
t.Errorf("Missing Bun-style primary key tag with autoincrement")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,14 +308,20 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
|||||||
filepointerStr := string(filepointerContent)
|
filepointerStr := string(filepointerContent)
|
||||||
|
|
||||||
// Should have two different has-many relationships with unique names
|
// Should have two different has-many relationships with unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []struct {
|
||||||
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
fieldName string
|
||||||
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
tag string
|
||||||
|
}{
|
||||||
|
{"RelRIDFilepointerRequestOrgAPIEvents", "join:id_filepointer=rid_filepointer_request"}, // Has many via rid_filepointer_request
|
||||||
|
{"RelRIDFilepointerResponseOrgAPIEvents", "join:id_filepointer=rid_filepointer_response"}, // Has many via rid_filepointer_response
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
if !strings.Contains(filepointerStr, exp) {
|
if !strings.Contains(filepointerStr, exp.fieldName) {
|
||||||
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp.fieldName, filepointerStr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(filepointerStr, exp.tag) {
|
||||||
|
t.Errorf("Missing has-many relationship join tag: %s\nGenerated:\n%s", exp.tag, filepointerStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -455,10 +461,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
|
|
||||||
// Verify all has-many relationships have unique names
|
// Verify all has-many relationships have unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
"RelRIDOwner", // Has one via rid_owner
|
"RelRIDOwner", // Has one via rid_owner
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -561,8 +567,8 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
|||||||
{"bigint", false, "resolvespec_common.SqlInt64"},
|
{"bigint", false, "resolvespec_common.SqlInt64"},
|
||||||
{"varchar", true, "resolvespec_common.SqlString"}, // Bun uses sql types even for NOT NULL strings
|
{"varchar", true, "resolvespec_common.SqlString"}, // Bun uses sql types even for NOT NULL strings
|
||||||
{"varchar", false, "resolvespec_common.SqlString"},
|
{"varchar", false, "resolvespec_common.SqlString"},
|
||||||
{"timestamp", true, "resolvespec_common.SqlTime"},
|
{"timestamp", true, "resolvespec_common.SqlTimeStamp"},
|
||||||
{"timestamp", false, "resolvespec_common.SqlTime"},
|
{"timestamp", false, "resolvespec_common.SqlTimeStamp"},
|
||||||
{"date", false, "resolvespec_common.SqlDate"},
|
{"date", false, "resolvespec_common.SqlDate"},
|
||||||
{"boolean", true, "bool"},
|
{"boolean", true, "bool"},
|
||||||
{"boolean", false, "resolvespec_common.SqlBool"},
|
{"boolean", false, "resolvespec_common.SqlBool"},
|
||||||
@@ -609,14 +615,75 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
|||||||
want: []string{"email,", "type:varchar(255),", "nullzero,"},
|
want: []string{"email,", "type:varchar(255),", "nullzero,"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with default",
|
name: "with default string",
|
||||||
column: &models.Column{
|
column: &models.Column{
|
||||||
Name: "status",
|
Name: "status",
|
||||||
Type: "text",
|
Type: "text",
|
||||||
NotNull: true,
|
NotNull: true,
|
||||||
Default: "active",
|
Default: "active",
|
||||||
},
|
},
|
||||||
want: []string{"status,", "type:text,", "default:active,"},
|
want: []string{"status,", "type:text,", "default:'active',"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with default integer",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "retries",
|
||||||
|
Type: "integer",
|
||||||
|
NotNull: true,
|
||||||
|
Default: "0",
|
||||||
|
},
|
||||||
|
want: []string{"retries,", "type:integer,", "default:0,"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with default boolean",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "active",
|
||||||
|
Type: "boolean",
|
||||||
|
NotNull: true,
|
||||||
|
Default: "true",
|
||||||
|
},
|
||||||
|
want: []string{"active,", "type:boolean,", "default:true,"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with default function call",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "created_at",
|
||||||
|
Type: "timestamp",
|
||||||
|
NotNull: true,
|
||||||
|
Default: "now()",
|
||||||
|
},
|
||||||
|
want: []string{"created_at,", "type:timestamp,", "default:now(),"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto increment with AutoIncrement flag",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
want: []string{"id,", "type:bigint,", "pk,", "autoincrement,"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "serial type (auto-increment)",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "serial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
want: []string{"id,", "type:serial,", "pk,", "autoincrement,"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bigserial type (auto-increment)",
|
||||||
|
column: &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,10 +62,10 @@ func (w *Writer) databaseToDBML(d *models.Database) string {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
if d.Description != "" {
|
if d.Description != "" {
|
||||||
sb.WriteString(fmt.Sprintf("// %s\n", d.Description))
|
fmt.Fprintf(&sb, "// %s\n", d.Description)
|
||||||
}
|
}
|
||||||
if d.Comment != "" {
|
if d.Comment != "" {
|
||||||
sb.WriteString(fmt.Sprintf("// %s\n", d.Comment))
|
fmt.Fprintf(&sb, "// %s\n", d.Comment)
|
||||||
}
|
}
|
||||||
if d.Description != "" || d.Comment != "" {
|
if d.Description != "" || d.Comment != "" {
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
@@ -94,7 +94,7 @@ func (w *Writer) schemaToDBML(schema *models.Schema) string {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
if schema.Description != "" {
|
if schema.Description != "" {
|
||||||
sb.WriteString(fmt.Sprintf("// Schema: %s - %s\n", schema.Name, schema.Description))
|
fmt.Fprintf(&sb, "// Schema: %s - %s\n", schema.Name, schema.Description)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
@@ -110,10 +110,10 @@ func (w *Writer) tableToDBML(t *models.Table) string {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
tableName := fmt.Sprintf("%s.%s", t.Schema, t.Name)
|
tableName := fmt.Sprintf("%s.%s", t.Schema, t.Name)
|
||||||
sb.WriteString(fmt.Sprintf("Table %s {\n", tableName))
|
fmt.Fprintf(&sb, "Table %s {\n", tableName)
|
||||||
|
|
||||||
for _, column := range t.Columns {
|
for _, column := range t.Columns {
|
||||||
sb.WriteString(fmt.Sprintf(" %s %s", column.Name, column.Type))
|
fmt.Fprintf(&sb, " %s %s", column.Name, column.Type)
|
||||||
|
|
||||||
var attrs []string
|
var attrs []string
|
||||||
if column.IsPrimaryKey {
|
if column.IsPrimaryKey {
|
||||||
@@ -138,11 +138,11 @@ func (w *Writer) tableToDBML(t *models.Table) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(attrs) > 0 {
|
if len(attrs) > 0 {
|
||||||
sb.WriteString(fmt.Sprintf(" [%s]", strings.Join(attrs, ", ")))
|
fmt.Fprintf(&sb, " [%s]", strings.Join(attrs, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
if column.Comment != "" {
|
if column.Comment != "" {
|
||||||
sb.WriteString(fmt.Sprintf(" // %s", column.Comment))
|
fmt.Fprintf(&sb, " // %s", column.Comment)
|
||||||
}
|
}
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
@@ -161,9 +161,9 @@ func (w *Writer) tableToDBML(t *models.Table) string {
|
|||||||
indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
|
indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf(" (%s)", strings.Join(index.Columns, ", ")))
|
fmt.Fprintf(&sb, " (%s)", strings.Join(index.Columns, ", "))
|
||||||
if len(indexAttrs) > 0 {
|
if len(indexAttrs) > 0 {
|
||||||
sb.WriteString(fmt.Sprintf(" [%s]", strings.Join(indexAttrs, ", ")))
|
fmt.Fprintf(&sb, " [%s]", strings.Join(indexAttrs, ", "))
|
||||||
}
|
}
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
@@ -172,7 +172,7 @@ func (w *Writer) tableToDBML(t *models.Table) string {
|
|||||||
|
|
||||||
note := strings.TrimSpace(t.Description + " " + t.Comment)
|
note := strings.TrimSpace(t.Description + " " + t.Comment)
|
||||||
if note != "" {
|
if note != "" {
|
||||||
sb.WriteString(fmt.Sprintf("\n Note: '%s'\n", note))
|
fmt.Fprintf(&sb, "\n Note: '%s'\n", note)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString("}\n")
|
sb.WriteString("}\n")
|
||||||
|
|||||||
@@ -158,10 +158,10 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
|
|||||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||||
|
|
||||||
// Date/Time types
|
// Date/Time types
|
||||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
"timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
"timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||||
@@ -238,8 +238,8 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
|||||||
|
|
||||||
// Default value
|
// Default value
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
// Sanitize default value to remove backticks
|
// Sanitize default value to remove backticks, then quote based on column type
|
||||||
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
safeDefault := writers.QuoteDefaultValue(writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default)), column.Type)
|
||||||
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -655,7 +655,7 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
|||||||
{"varchar", true, "string"},
|
{"varchar", true, "string"},
|
||||||
{"varchar", false, "sql_types.SqlString"},
|
{"varchar", false, "sql_types.SqlString"},
|
||||||
{"timestamp", true, "time.Time"},
|
{"timestamp", true, "time.Time"},
|
||||||
{"timestamp", false, "sql_types.SqlTime"},
|
{"timestamp", false, "sql_types.SqlTimeStamp"},
|
||||||
{"boolean", true, "bool"},
|
{"boolean", true, "bool"},
|
||||||
{"boolean", false, "sql_types.SqlBool"},
|
{"boolean", false, "sql_types.SqlBool"},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (w *Writer) databaseToGraphQL(db *models.Database) string {
|
|||||||
if w.shouldIncludeComments() {
|
if w.shouldIncludeComments() {
|
||||||
sb.WriteString("# Generated GraphQL Schema\n")
|
sb.WriteString("# Generated GraphQL Schema\n")
|
||||||
if db.Name != "" {
|
if db.Name != "" {
|
||||||
sb.WriteString(fmt.Sprintf("# Database: %s\n", db.Name))
|
fmt.Fprintf(&sb, "# Database: %s\n", db.Name)
|
||||||
}
|
}
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
@@ -62,7 +62,7 @@ func (w *Writer) databaseToGraphQL(db *models.Database) string {
|
|||||||
scalars := w.collectCustomScalars(db)
|
scalars := w.collectCustomScalars(db)
|
||||||
if len(scalars) > 0 {
|
if len(scalars) > 0 {
|
||||||
for _, scalar := range scalars {
|
for _, scalar := range scalars {
|
||||||
sb.WriteString(fmt.Sprintf("scalar %s\n", scalar))
|
fmt.Fprintf(&sb, "scalar %s\n", scalar)
|
||||||
}
|
}
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
@@ -176,9 +176,9 @@ func (w *Writer) isJoinTable(table *models.Table) bool {
|
|||||||
func (w *Writer) enumToGraphQL(enum *models.Enum) string {
|
func (w *Writer) enumToGraphQL(enum *models.Enum) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("enum %s {\n", enum.Name))
|
fmt.Fprintf(&sb, "enum %s {\n", enum.Name)
|
||||||
for _, value := range enum.Values {
|
for _, value := range enum.Values {
|
||||||
sb.WriteString(fmt.Sprintf(" %s\n", value))
|
fmt.Fprintf(&sb, " %s\n", value)
|
||||||
}
|
}
|
||||||
sb.WriteString("}\n")
|
sb.WriteString("}\n")
|
||||||
|
|
||||||
@@ -197,10 +197,10 @@ func (w *Writer) tableToGraphQL(table *models.Table, db *models.Database, schema
|
|||||||
if desc == "" {
|
if desc == "" {
|
||||||
desc = table.Comment
|
desc = table.Comment
|
||||||
}
|
}
|
||||||
sb.WriteString(fmt.Sprintf("# %s\n", desc))
|
fmt.Fprintf(&sb, "# %s\n", desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("type %s {\n", typeName))
|
fmt.Fprintf(&sb, "type %s {\n", typeName)
|
||||||
|
|
||||||
// Collect and categorize fields
|
// Collect and categorize fields
|
||||||
var idFields, scalarFields, relationFields []string
|
var idFields, scalarFields, relationFields []string
|
||||||
|
|||||||
@@ -125,9 +125,9 @@ func (w *Writer) generateGenerator() string {
|
|||||||
func (w *Writer) enumToPrisma(enum *models.Enum) string {
|
func (w *Writer) enumToPrisma(enum *models.Enum) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("enum %s {\n", enum.Name))
|
fmt.Fprintf(&sb, "enum %s {\n", enum.Name)
|
||||||
for _, value := range enum.Values {
|
for _, value := range enum.Values {
|
||||||
sb.WriteString(fmt.Sprintf(" %s\n", value))
|
fmt.Fprintf(&sb, " %s\n", value)
|
||||||
}
|
}
|
||||||
sb.WriteString("}\n")
|
sb.WriteString("}\n")
|
||||||
|
|
||||||
@@ -179,7 +179,7 @@ func (w *Writer) identifyJoinTables(schema *models.Schema) map[string]bool {
|
|||||||
func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
func (w *Writer) tableToPrisma(table *models.Table, schema *models.Schema, joinTables map[string]bool) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("model %s {\n", table.Name))
|
fmt.Fprintf(&sb, "model %s {\n", table.Name)
|
||||||
|
|
||||||
// Collect columns to write
|
// Collect columns to write
|
||||||
columns := make([]*models.Column, 0, len(table.Columns))
|
columns := make([]*models.Column, 0, len(table.Columns))
|
||||||
@@ -219,11 +219,11 @@ func (w *Writer) columnToField(col *models.Column, table *models.Table, schema *
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
// Field name
|
// Field name
|
||||||
sb.WriteString(fmt.Sprintf(" %s", col.Name))
|
fmt.Fprintf(&sb, " %s", col.Name)
|
||||||
|
|
||||||
// Field type
|
// Field type
|
||||||
prismaType := w.sqlTypeToPrisma(col.Type, schema)
|
prismaType := w.sqlTypeToPrisma(col.Type, schema)
|
||||||
sb.WriteString(fmt.Sprintf(" %s", prismaType))
|
fmt.Fprintf(&sb, " %s", prismaType)
|
||||||
|
|
||||||
// Optional modifier
|
// Optional modifier
|
||||||
if !col.NotNull && !col.IsPrimaryKey {
|
if !col.NotNull && !col.IsPrimaryKey {
|
||||||
@@ -413,7 +413,7 @@ func (w *Writer) generateRelationFields(table *models.Table, schema *models.Sche
|
|||||||
relationName = relationName[:len(relationName)-1]
|
relationName = relationName[:len(relationName)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf(" %s %s", strings.ToLower(relationName), relationType))
|
fmt.Fprintf(&sb, " %s %s", strings.ToLower(relationName), relationType)
|
||||||
|
|
||||||
if isOptional {
|
if isOptional {
|
||||||
sb.WriteString("?")
|
sb.WriteString("?")
|
||||||
@@ -479,8 +479,8 @@ func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Sc
|
|||||||
if fk.ReferencedTable != table.Name {
|
if fk.ReferencedTable != table.Name {
|
||||||
// This is the other side
|
// This is the other side
|
||||||
otherSide := fk.ReferencedTable
|
otherSide := fk.ReferencedTable
|
||||||
sb.WriteString(fmt.Sprintf(" %ss %s[]\n",
|
fmt.Fprintf(&sb, " %ss %s[]\n",
|
||||||
strings.ToLower(otherSide), otherSide))
|
strings.ToLower(otherSide), otherSide)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -497,8 +497,8 @@ func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Sc
|
|||||||
pluralName += "s"
|
pluralName += "s"
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf(" %s %s[]\n",
|
fmt.Fprintf(&sb, " %s %s[]\n",
|
||||||
strings.ToLower(pluralName), otherTable.Name))
|
strings.ToLower(pluralName), otherTable.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -530,20 +530,20 @@ func (w *Writer) generateBlockAttributes(table *models.Table) string {
|
|||||||
|
|
||||||
if len(pkCols) > 1 {
|
if len(pkCols) > 1 {
|
||||||
sort.Strings(pkCols)
|
sort.Strings(pkCols)
|
||||||
sb.WriteString(fmt.Sprintf(" @@id([%s])\n", strings.Join(pkCols, ", ")))
|
fmt.Fprintf(&sb, " @@id([%s])\n", strings.Join(pkCols, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
// @@unique for multi-column unique constraints
|
// @@unique for multi-column unique constraints
|
||||||
for _, constraint := range table.Constraints {
|
for _, constraint := range table.Constraints {
|
||||||
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
|
if constraint.Type == models.UniqueConstraint && len(constraint.Columns) > 1 {
|
||||||
sb.WriteString(fmt.Sprintf(" @@unique([%s])\n", strings.Join(constraint.Columns, ", ")))
|
fmt.Fprintf(&sb, " @@unique([%s])\n", strings.Join(constraint.Columns, ", "))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// @@index for indexes
|
// @@index for indexes
|
||||||
for _, index := range table.Indexes {
|
for _, index := range table.Indexes {
|
||||||
if !index.Unique { // Unique indexes are handled by @@unique
|
if !index.Unique { // Unique indexes are handled by @@unique
|
||||||
sb.WriteString(fmt.Sprintf(" @@index([%s])\n", strings.Join(index.Columns, ", ")))
|
fmt.Fprintf(&sb, " @@index([%s])\n", strings.Join(index.Columns, ", "))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ func (w *Writer) tableToEntity(table *models.Table, schema *models.Schema, joinT
|
|||||||
|
|
||||||
// Generate @Entity decorator with options
|
// Generate @Entity decorator with options
|
||||||
entityOptions := w.buildEntityOptions(table)
|
entityOptions := w.buildEntityOptions(table)
|
||||||
sb.WriteString(fmt.Sprintf("@Entity({\n%s\n})\n", entityOptions))
|
fmt.Fprintf(&sb, "@Entity({\n%s\n})\n", entityOptions)
|
||||||
|
|
||||||
// Get class name (from metadata if different from table name)
|
// Get class name (from metadata if different from table name)
|
||||||
className := table.Name
|
className := table.Name
|
||||||
@@ -219,7 +219,7 @@ func (w *Writer) tableToEntity(table *models.Table, schema *models.Schema, joinT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("export class %s {\n", className))
|
fmt.Fprintf(&sb, "export class %s {\n", className)
|
||||||
|
|
||||||
// Collect and sort columns
|
// Collect and sort columns
|
||||||
columns := make([]*models.Column, 0, len(table.Columns))
|
columns := make([]*models.Column, 0, len(table.Columns))
|
||||||
@@ -272,7 +272,7 @@ func (w *Writer) viewToEntity(view *models.View) string {
|
|||||||
sb.WriteString("})\n")
|
sb.WriteString("})\n")
|
||||||
|
|
||||||
// Generate class
|
// Generate class
|
||||||
sb.WriteString(fmt.Sprintf("export class %s {\n", view.Name))
|
fmt.Fprintf(&sb, "export class %s {\n", view.Name)
|
||||||
|
|
||||||
// Generate field definitions (without decorators for view fields)
|
// Generate field definitions (without decorators for view fields)
|
||||||
columns := make([]*models.Column, 0, len(view.Columns))
|
columns := make([]*models.Column, 0, len(view.Columns))
|
||||||
@@ -285,7 +285,7 @@ func (w *Writer) viewToEntity(view *models.View) string {
|
|||||||
|
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
tsType := w.sqlTypeToTypeScript(col.Type)
|
tsType := w.sqlTypeToTypeScript(col.Type)
|
||||||
sb.WriteString(fmt.Sprintf(" %s: %s;\n", col.Name, tsType))
|
fmt.Fprintf(&sb, " %s: %s;\n", col.Name, tsType)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString("}\n")
|
sb.WriteString("}\n")
|
||||||
@@ -314,7 +314,7 @@ func (w *Writer) columnToField(col *models.Column, table *models.Table) string {
|
|||||||
// Regular @Column decorator
|
// Regular @Column decorator
|
||||||
options := w.buildColumnOptions(col, table)
|
options := w.buildColumnOptions(col, table)
|
||||||
if options != "" {
|
if options != "" {
|
||||||
sb.WriteString(fmt.Sprintf(" @Column({ %s })\n", options))
|
fmt.Fprintf(&sb, " @Column({ %s })\n", options)
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString(" @Column()\n")
|
sb.WriteString(" @Column()\n")
|
||||||
}
|
}
|
||||||
@@ -327,7 +327,7 @@ func (w *Writer) columnToField(col *models.Column, table *models.Table) string {
|
|||||||
nullable = " | null"
|
nullable = " | null"
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf(" %s: %s%s;", col.Name, tsType, nullable))
|
fmt.Fprintf(&sb, " %s: %s%s;", col.Name, tsType, nullable)
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
@@ -464,17 +464,17 @@ func (w *Writer) generateRelationFields(table *models.Table, schema *models.Sche
|
|||||||
inverseField := w.findInverseFieldName(table.Name, relatedTable, schema)
|
inverseField := w.findInverseFieldName(table.Name, relatedTable, schema)
|
||||||
|
|
||||||
if inverseField != "" {
|
if inverseField != "" {
|
||||||
sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, %s => %s.%s)\n",
|
fmt.Fprintf(&sb, " @ManyToOne(() => %s, %s => %s.%s)\n",
|
||||||
relatedTable, strings.ToLower(relatedTable), strings.ToLower(relatedTable), inverseField))
|
relatedTable, strings.ToLower(relatedTable), strings.ToLower(relatedTable), inverseField)
|
||||||
} else {
|
} else {
|
||||||
if isNullable {
|
if isNullable {
|
||||||
sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s, { nullable: true })\n", relatedTable))
|
fmt.Fprintf(&sb, " @ManyToOne(() => %s, { nullable: true })\n", relatedTable)
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString(fmt.Sprintf(" @ManyToOne(() => %s)\n", relatedTable))
|
fmt.Fprintf(&sb, " @ManyToOne(() => %s)\n", relatedTable)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf(" %s: %s%s;\n", fieldName, relatedTable, nullable))
|
fmt.Fprintf(&sb, " %s: %s%s;\n", fieldName, relatedTable, nullable)
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,64 @@ func SanitizeFilename(name string) string {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QuoteDefaultValue wraps a sanitized default value in single quotes when the SQL
|
||||||
|
// column type requires it (strings, dates, times, UUIDs, enums). Numeric types
|
||||||
|
// (integers, floats, serials) and boolean types are left unquoted. Function-call
|
||||||
|
// expressions such as now() or gen_random_uuid() are always left unquoted regardless
|
||||||
|
// of type, because they contain parentheses.
|
||||||
|
//
|
||||||
|
// Examples (varchar): "disconnected" → "'disconnected'"
|
||||||
|
// Examples (boolean): "true" → "true"
|
||||||
|
// Examples (bigint): "0" → "0"
|
||||||
|
// Examples (timestamp): "now()" → "now()" (function call – never quoted)
|
||||||
|
func QuoteDefaultValue(value, sqlType string) string {
|
||||||
|
// Function calls are never quoted regardless of column type.
|
||||||
|
if strings.Contains(value, "(") || strings.Contains(value, ")") {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalise the SQL type: lowercase, strip length/precision suffix.
|
||||||
|
baseType := strings.ToLower(strings.TrimSpace(sqlType))
|
||||||
|
if idx := strings.Index(baseType, "("); idx > 0 {
|
||||||
|
baseType = baseType[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Types whose default values must NOT be quoted.
|
||||||
|
unquotedTypes := map[string]bool{
|
||||||
|
// Integer types
|
||||||
|
"integer": true,
|
||||||
|
"int": true,
|
||||||
|
"int2": true,
|
||||||
|
"int4": true,
|
||||||
|
"int8": true,
|
||||||
|
"smallint": true,
|
||||||
|
"bigint": true,
|
||||||
|
"serial": true,
|
||||||
|
"smallserial": true,
|
||||||
|
"bigserial": true,
|
||||||
|
// Float / numeric types
|
||||||
|
"real": true,
|
||||||
|
"float": true,
|
||||||
|
"float4": true,
|
||||||
|
"float8": true,
|
||||||
|
"double precision": true,
|
||||||
|
"numeric": true,
|
||||||
|
"decimal": true,
|
||||||
|
"money": true,
|
||||||
|
// Boolean
|
||||||
|
"boolean": true,
|
||||||
|
"bool": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if unquotedTypes[baseType] {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
|
||||||
|
// is treated as a quoted literal.
|
||||||
|
return "'" + value + "'"
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
||||||
// Go struct tags are delimited by backticks, so any backtick in the value would break the syntax.
|
// Go struct tags are delimited by backticks, so any backtick in the value would break the syntax.
|
||||||
// This function:
|
// This function:
|
||||||
|
|||||||
Reference in New Issue
Block a user