Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fb09b78c3 | ||
| 5d9770b430 | |||
| f2d500f98d | |||
| 2ec9991324 | |||
| a3e45c206d | |||
| 165623bb1d | |||
| 3c20c3c5d9 | |||
| a54594e49b | |||
| cafe6a461f | |||
| abdb9b4c78 | |||
| e7a15c8e4f | |||
| c36b5ede2b | |||
| 51ab29f8e3 | |||
| f532fc110c | |||
| 92dff99725 | |||
| 283b568adb | |||
| 122743ee43 | |||
| 91b6046b9b | |||
| 6f55505444 | |||
| e0e7b64c69 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -47,3 +47,4 @@ dist/
|
|||||||
build/
|
build/
|
||||||
bin/
|
bin/
|
||||||
tests/integration/failed_statements_example.txt
|
tests/integration/failed_statements_example.txt
|
||||||
|
test_output.log
|
||||||
|
|||||||
@@ -38,13 +38,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
convertSourceType string
|
convertSourceType string
|
||||||
convertSourcePath string
|
convertSourcePath string
|
||||||
convertSourceConn string
|
convertSourceConn string
|
||||||
convertTargetType string
|
convertTargetType string
|
||||||
convertTargetPath string
|
convertTargetPath string
|
||||||
convertPackageName string
|
convertPackageName string
|
||||||
convertSchemaFilter string
|
convertSchemaFilter string
|
||||||
|
convertFlattenSchema bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var convertCmd = &cobra.Command{
|
var convertCmd = &cobra.Command{
|
||||||
@@ -148,6 +149,7 @@ func init() {
|
|||||||
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
|
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
|
||||||
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
||||||
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
||||||
|
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||||
|
|
||||||
err := convertCmd.MarkFlagRequired("from")
|
err := convertCmd.MarkFlagRequired("from")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -202,7 +204,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter); err != nil {
|
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema); err != nil {
|
||||||
return fmt.Errorf("failed to write target: %w", err)
|
return fmt.Errorf("failed to write target: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,12 +303,13 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string) error {
|
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool) error {
|
||||||
var writer writers.Writer
|
var writer writers.Writer
|
||||||
|
|
||||||
writerOpts := &writers.WriterOptions{
|
writerOpts := &writers.WriterOptions{
|
||||||
OutputPath: outputPath,
|
OutputPath: outputPath,
|
||||||
PackageName: packageName,
|
PackageName: packageName,
|
||||||
|
FlattenSchema: flattenSchema,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(dbType) {
|
switch strings.ToLower(dbType) {
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ var (
|
|||||||
mergeSkipSequences bool
|
mergeSkipSequences bool
|
||||||
mergeSkipTables string // Comma-separated table names to skip
|
mergeSkipTables string // Comma-separated table names to skip
|
||||||
mergeVerbose bool
|
mergeVerbose bool
|
||||||
|
mergeReportPath string // Path to write merge report
|
||||||
|
mergeFlattenSchema bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var mergeCmd = &cobra.Command{
|
var mergeCmd = &cobra.Command{
|
||||||
@@ -78,6 +80,12 @@ Examples:
|
|||||||
--source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
|
--source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
|
||||||
--output json --output-path combined.json
|
--output json --output-path combined.json
|
||||||
|
|
||||||
|
# Merge and execute on PostgreSQL database with report
|
||||||
|
relspec merge --target json --target-path base.json \
|
||||||
|
--source json --source-path additional.json \
|
||||||
|
--output pgsql --output-conn "postgres://user:pass@localhost/target_db" \
|
||||||
|
--merge-report merge-report.json
|
||||||
|
|
||||||
# Merge DBML and YAML, skip relations
|
# Merge DBML and YAML, skip relations
|
||||||
relspec merge --target dbml --target-path schema.dbml \
|
relspec merge --target dbml --target-path schema.dbml \
|
||||||
--source yaml --source-path tables.yaml \
|
--source yaml --source-path tables.yaml \
|
||||||
@@ -115,6 +123,8 @@ func init() {
|
|||||||
mergeCmd.Flags().BoolVar(&mergeSkipSequences, "skip-sequences", false, "Skip sequences during merge")
|
mergeCmd.Flags().BoolVar(&mergeSkipSequences, "skip-sequences", false, "Skip sequences during merge")
|
||||||
mergeCmd.Flags().StringVar(&mergeSkipTables, "skip-tables", "", "Comma-separated list of table names to skip during merge")
|
mergeCmd.Flags().StringVar(&mergeSkipTables, "skip-tables", "", "Comma-separated list of table names to skip during merge")
|
||||||
mergeCmd.Flags().BoolVar(&mergeVerbose, "verbose", false, "Show verbose output")
|
mergeCmd.Flags().BoolVar(&mergeVerbose, "verbose", false, "Show verbose output")
|
||||||
|
mergeCmd.Flags().StringVar(&mergeReportPath, "merge-report", "", "Path to write merge report (JSON format)")
|
||||||
|
mergeCmd.Flags().BoolVar(&mergeFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func runMerge(cmd *cobra.Command, args []string) error {
|
func runMerge(cmd *cobra.Command, args []string) error {
|
||||||
@@ -229,7 +239,7 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeOutputPath)
|
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeOutputPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, "", targetDB, "Output")
|
err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, mergeOutputConn, targetDB, "Output", mergeFlattenSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write output: %w", err)
|
return fmt.Errorf("failed to write output: %w", err)
|
||||||
}
|
}
|
||||||
@@ -316,7 +326,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string) error {
|
func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string, flattenSchema bool) error {
|
||||||
var writer writers.Writer
|
var writer writers.Writer
|
||||||
|
|
||||||
switch strings.ToLower(dbType) {
|
switch strings.ToLower(dbType) {
|
||||||
@@ -324,59 +334,69 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DBML format", label)
|
return fmt.Errorf("%s: file path is required for DBML format", label)
|
||||||
}
|
}
|
||||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "dctx":
|
case "dctx":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||||
}
|
}
|
||||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "drawdb":
|
case "drawdb":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||||
}
|
}
|
||||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "graphql":
|
case "graphql":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||||
}
|
}
|
||||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "json":
|
case "json":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for JSON format", label)
|
return fmt.Errorf("%s: file path is required for JSON format", label)
|
||||||
}
|
}
|
||||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "yaml":
|
case "yaml":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for YAML format", label)
|
return fmt.Errorf("%s: file path is required for YAML format", label)
|
||||||
}
|
}
|
||||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "gorm":
|
case "gorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for GORM format", label)
|
return fmt.Errorf("%s: file path is required for GORM format", label)
|
||||||
}
|
}
|
||||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "bun":
|
case "bun":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Bun format", label)
|
return fmt.Errorf("%s: file path is required for Bun format", label)
|
||||||
}
|
}
|
||||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "drizzle":
|
case "drizzle":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||||
}
|
}
|
||||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "prisma":
|
case "prisma":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||||
}
|
}
|
||||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "typeorm":
|
case "typeorm":
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||||
}
|
}
|
||||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||||
case "pgsql":
|
case "pgsql":
|
||||||
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}
|
||||||
|
if connString != "" {
|
||||||
|
writerOpts.Metadata = map[string]interface{}{
|
||||||
|
"connection_string": connString,
|
||||||
|
}
|
||||||
|
// Add report path if merge report is enabled
|
||||||
|
if mergeReportPath != "" {
|
||||||
|
writerOpts.Metadata["report_path"] = mergeReportPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writer = wpgsql.NewWriter(writerOpts)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
return fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
scriptsDir string
|
scriptsDir string
|
||||||
scriptsConn string
|
scriptsConn string
|
||||||
scriptsSchemaName string
|
scriptsSchemaName string
|
||||||
scriptsDBName string
|
scriptsDBName string
|
||||||
|
scriptsIgnoreErrors bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var scriptsCmd = &cobra.Command{
|
var scriptsCmd = &cobra.Command{
|
||||||
@@ -39,8 +40,8 @@ Example filenames (hyphen format):
|
|||||||
1-002-create-posts.sql # Priority 1, Sequence 2
|
1-002-create-posts.sql # Priority 1, Sequence 2
|
||||||
10-10-create-newid.pgsql # Priority 10, Sequence 10
|
10-10-create-newid.pgsql # Priority 10, Sequence 10
|
||||||
|
|
||||||
Both formats can be mixed in the same directory.
|
Both formats can be mixed in the same directory and subdirectories.
|
||||||
Scripts are executed in order: Priority (ascending), then Sequence (ascending).`,
|
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).`,
|
||||||
}
|
}
|
||||||
|
|
||||||
var scriptsListCmd = &cobra.Command{
|
var scriptsListCmd = &cobra.Command{
|
||||||
@@ -48,8 +49,8 @@ var scriptsListCmd = &cobra.Command{
|
|||||||
Short: "List SQL scripts from a directory",
|
Short: "List SQL scripts from a directory",
|
||||||
Long: `List SQL scripts from a directory and show their execution order.
|
Long: `List SQL scripts from a directory and show their execution order.
|
||||||
|
|
||||||
The scripts are read from the specified directory and displayed in the order
|
The scripts are read recursively from the specified directory and displayed in the order
|
||||||
they would be executed (Priority ascending, then Sequence ascending).
|
they would be executed: Priority (ascending), then Sequence (ascending), then Name (alphabetical).
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
relspec scripts list --dir ./migrations`,
|
relspec scripts list --dir ./migrations`,
|
||||||
@@ -61,10 +62,10 @@ var scriptsExecuteCmd = &cobra.Command{
|
|||||||
Short: "Execute SQL scripts against a database",
|
Short: "Execute SQL scripts against a database",
|
||||||
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
|
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
|
||||||
|
|
||||||
Scripts are executed in order: Priority (ascending), then Sequence (ascending).
|
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).
|
||||||
Execution stops immediately on the first error.
|
By default, execution stops immediately on the first error. Use --ignore-errors to continue execution.
|
||||||
|
|
||||||
The directory is scanned recursively for files matching the patterns:
|
The directory is scanned recursively for all subdirectories and files matching the patterns:
|
||||||
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
|
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
|
||||||
{priority}-{sequence}-{name}.sql or .pgsql (hyphen format)
|
{priority}-{sequence}-{name}.sql or .pgsql (hyphen format)
|
||||||
|
|
||||||
@@ -75,7 +76,7 @@ PostgreSQL Connection String Examples:
|
|||||||
postgresql://user:pass@host/dbname?sslmode=require
|
postgresql://user:pass@host/dbname?sslmode=require
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
# Execute migration scripts
|
# Execute migration scripts from a directory (including subdirectories)
|
||||||
relspec scripts execute --dir ./migrations \
|
relspec scripts execute --dir ./migrations \
|
||||||
--conn "postgres://user:pass@localhost:5432/mydb"
|
--conn "postgres://user:pass@localhost:5432/mydb"
|
||||||
|
|
||||||
@@ -86,7 +87,12 @@ Examples:
|
|||||||
|
|
||||||
# Execute with SSL disabled
|
# Execute with SSL disabled
|
||||||
relspec scripts execute --dir ./sql \
|
relspec scripts execute --dir ./sql \
|
||||||
--conn "postgres://user:pass@localhost/db?sslmode=disable"`,
|
--conn "postgres://user:pass@localhost/db?sslmode=disable"
|
||||||
|
|
||||||
|
# Continue executing even if errors occur
|
||||||
|
relspec scripts execute --dir ./migrations \
|
||||||
|
--conn "postgres://localhost/mydb" \
|
||||||
|
--ignore-errors`,
|
||||||
RunE: runScriptsExecute,
|
RunE: runScriptsExecute,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,6 +111,7 @@ func init() {
|
|||||||
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
|
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
|
||||||
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
|
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
|
||||||
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
|
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
|
||||||
|
scriptsExecuteCmd.Flags().BoolVar(&scriptsIgnoreErrors, "ignore-errors", false, "Continue executing scripts even if errors occur")
|
||||||
|
|
||||||
err = scriptsExecuteCmd.MarkFlagRequired("dir")
|
err = scriptsExecuteCmd.MarkFlagRequired("dir")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -149,7 +156,7 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort scripts by Priority then Sequence
|
// Sort scripts by Priority, Sequence, then Name
|
||||||
sortedScripts := make([]*struct {
|
sortedScripts := make([]*struct {
|
||||||
name string
|
name string
|
||||||
priority int
|
priority int
|
||||||
@@ -186,7 +193,10 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
|
|||||||
if sortedScripts[i].priority != sortedScripts[j].priority {
|
if sortedScripts[i].priority != sortedScripts[j].priority {
|
||||||
return sortedScripts[i].priority < sortedScripts[j].priority
|
return sortedScripts[i].priority < sortedScripts[j].priority
|
||||||
}
|
}
|
||||||
return sortedScripts[i].sequence < sortedScripts[j].sequence
|
if sortedScripts[i].sequence != sortedScripts[j].sequence {
|
||||||
|
return sortedScripts[i].sequence < sortedScripts[j].sequence
|
||||||
|
}
|
||||||
|
return sortedScripts[i].name < sortedScripts[j].name
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Found %d script(s) in execution order:\n\n", len(sortedScripts))
|
fmt.Fprintf(os.Stderr, "Found %d script(s) in execution order:\n\n", len(sortedScripts))
|
||||||
@@ -242,22 +252,44 @@ func runScriptsExecute(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " ✓ Found %d script(s)\n\n", len(schema.Scripts))
|
fmt.Fprintf(os.Stderr, " ✓ Found %d script(s)\n\n", len(schema.Scripts))
|
||||||
|
|
||||||
// Step 2: Execute scripts
|
// Step 2: Execute scripts
|
||||||
fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence)...\n\n")
|
fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence → Name)...\n\n")
|
||||||
|
|
||||||
writer := sqlexec.NewWriter(&writers.WriterOptions{
|
writer := sqlexec.NewWriter(&writers.WriterOptions{
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"connection_string": scriptsConn,
|
"connection_string": scriptsConn,
|
||||||
|
"ignore_errors": scriptsIgnoreErrors,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := writer.WriteSchema(schema); err != nil {
|
if err := writer.WriteSchema(schema); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "\n")
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
return fmt.Errorf("execution failed: %w", err)
|
return fmt.Errorf("script execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get execution results from writer metadata
|
||||||
|
totalCount := len(schema.Scripts)
|
||||||
|
successCount := totalCount
|
||||||
|
failedCount := 0
|
||||||
|
|
||||||
|
opts := writer.Options()
|
||||||
|
if total, exists := opts.Metadata["execution_total"].(int); exists {
|
||||||
|
totalCount = total
|
||||||
|
}
|
||||||
|
if success, exists := opts.Metadata["execution_success"].(int); exists {
|
||||||
|
successCount = success
|
||||||
|
}
|
||||||
|
if failed, exists := opts.Metadata["execution_failed"].(int); exists {
|
||||||
|
failedCount = failed
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
|
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
|
||||||
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
|
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
|
||||||
fmt.Fprintf(os.Stderr, "Successfully executed %d script(s)\n\n", len(schema.Scripts))
|
fmt.Fprintf(os.Stderr, "Total scripts: %d\n", totalCount)
|
||||||
|
fmt.Fprintf(os.Stderr, "Successful: %d\n", successCount)
|
||||||
|
if failedCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed: %d\n", failedCount)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -183,7 +183,8 @@ func runSplit(cmd *cobra.Command, args []string) error {
|
|||||||
splitTargetType,
|
splitTargetType,
|
||||||
splitTargetPath,
|
splitTargetPath,
|
||||||
splitPackageName,
|
splitPackageName,
|
||||||
"", // no schema filter for split
|
"", // no schema filter for split
|
||||||
|
false, // no flatten-schema for split
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write output: %w", err)
|
return fmt.Errorf("failed to write output: %w", err)
|
||||||
|
|||||||
714
pkg/commontypes/commontypes_test.go
Normal file
714
pkg/commontypes/commontypes_test.go
Normal file
@@ -0,0 +1,714 @@
|
|||||||
|
package commontypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractBaseType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"varchar with length", "varchar(100)", "varchar"},
|
||||||
|
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
|
||||||
|
{"numeric with precision", "numeric(10,2)", "numeric"},
|
||||||
|
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
|
||||||
|
{"decimal with precision", "decimal(15,3)", "decimal"},
|
||||||
|
{"char with length", "char(50)", "char"},
|
||||||
|
{"simple integer", "integer", "integer"},
|
||||||
|
{"simple text", "text", "text"},
|
||||||
|
{"bigint", "bigint", "bigint"},
|
||||||
|
{"With spaces", " varchar(100) ", "varchar"},
|
||||||
|
{"No parentheses", "boolean", "boolean"},
|
||||||
|
{"Empty string", "", ""},
|
||||||
|
{"Mixed case", "VarChar(100)", "varchar"},
|
||||||
|
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ExtractBaseType(tt.sqlType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeType(t *testing.T) {
|
||||||
|
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
|
||||||
|
testCases := []string{
|
||||||
|
"varchar(100)",
|
||||||
|
"numeric(10,2)",
|
||||||
|
"integer",
|
||||||
|
"text",
|
||||||
|
" VARCHAR(255) ",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc, func(t *testing.T) {
|
||||||
|
extracted := ExtractBaseType(tc)
|
||||||
|
normalized := NormalizeType(tc)
|
||||||
|
if extracted != normalized {
|
||||||
|
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
|
||||||
|
tc, extracted, tc, normalized)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToGo(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "int32"},
|
||||||
|
{"bigint nullable", "bigint", true, "int64"},
|
||||||
|
{"smallint nullable", "smallint", true, "int16"},
|
||||||
|
{"serial nullable", "serial", true, "int32"},
|
||||||
|
|
||||||
|
// Integer types (not nullable)
|
||||||
|
{"integer not nullable", "integer", false, "*int32"},
|
||||||
|
{"bigint not nullable", "bigint", false, "*int64"},
|
||||||
|
{"smallint not nullable", "smallint", false, "*int16"},
|
||||||
|
|
||||||
|
// String types (nullable)
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar with length nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// String types (not nullable)
|
||||||
|
{"text not nullable", "text", false, "*string"},
|
||||||
|
{"varchar not nullable", "varchar", false, "*string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "*bool"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "float32"},
|
||||||
|
{"double precision nullable", "double precision", true, "float64"},
|
||||||
|
{"real not nullable", "real", false, "*float32"},
|
||||||
|
{"double precision not nullable", "double precision", false, "*float64"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "time.Time"},
|
||||||
|
{"date nullable", "date", true, "time.Time"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "*time.Time"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "[]byte"},
|
||||||
|
{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "string"},
|
||||||
|
{"uuid not nullable", "uuid", false, "*string"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "string"},
|
||||||
|
{"jsonb nullable", "jsonb", true, "string"},
|
||||||
|
|
||||||
|
// Array
|
||||||
|
{"array nullable", "array", true, "[]string"},
|
||||||
|
{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "interface{}"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToGo(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToTypeScript(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer nullable", "integer", true, "number"},
|
||||||
|
{"integer not nullable", "integer", false, "number | null"},
|
||||||
|
{"bigint nullable", "bigint", true, "number"},
|
||||||
|
{"bigint not nullable", "bigint", false, "number | null"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"text not nullable", "text", false, "string | null"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "boolean"},
|
||||||
|
{"boolean not nullable", "boolean", false, "boolean | null"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "number"},
|
||||||
|
{"double precision nullable", "double precision", true, "number"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "Date"},
|
||||||
|
{"date nullable", "date", true, "Date"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "Date | null"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "Buffer"},
|
||||||
|
{"bytea not nullable", "bytea", false, "Buffer | null"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "any"},
|
||||||
|
{"jsonb nullable", "jsonb", true, "any"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "string"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "any"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "any | null"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToTypeScript(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToPython(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer", "integer", "int"},
|
||||||
|
{"bigint", "bigint", "int"},
|
||||||
|
{"smallint", "smallint", "int"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text", "text", "str"},
|
||||||
|
{"varchar", "varchar", "str"},
|
||||||
|
{"varchar(100)", "varchar(100)", "str"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean", "boolean", "bool"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real", "real", "float"},
|
||||||
|
{"double precision", "double precision", "float"},
|
||||||
|
{"numeric", "numeric", "Decimal"},
|
||||||
|
{"decimal", "decimal", "Decimal"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp", "timestamp", "datetime"},
|
||||||
|
{"date", "date", "date"},
|
||||||
|
{"time", "time", "time"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea", "bytea", "bytes"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json", "json", "dict"},
|
||||||
|
{"jsonb", "jsonb", "dict"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid", "uuid", "UUID"},
|
||||||
|
|
||||||
|
// Array
|
||||||
|
{"array", "array", "list"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type", "unknowntype", "Any"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToPython(tt.sqlType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToCSharp(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "int"},
|
||||||
|
{"bigint nullable", "bigint", true, "long"},
|
||||||
|
{"smallint nullable", "smallint", true, "short"},
|
||||||
|
|
||||||
|
// Integer types (not nullable - value types get ?)
|
||||||
|
{"integer not nullable", "integer", false, "int?"},
|
||||||
|
{"bigint not nullable", "bigint", false, "long?"},
|
||||||
|
{"smallint not nullable", "smallint", false, "short?"},
|
||||||
|
|
||||||
|
// String types (reference types, no ? needed)
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"text not nullable", "text", false, "string"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "bool?"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "float"},
|
||||||
|
{"double precision nullable", "double precision", true, "double"},
|
||||||
|
{"decimal nullable", "decimal", true, "decimal"},
|
||||||
|
{"real not nullable", "real", false, "float?"},
|
||||||
|
{"double precision not nullable", "double precision", false, "double?"},
|
||||||
|
{"decimal not nullable", "decimal", false, "decimal?"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "DateTime"},
|
||||||
|
{"date nullable", "date", true, "DateTime"},
|
||||||
|
{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "DateTime?"},
|
||||||
|
{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},
|
||||||
|
|
||||||
|
// Binary (array type, no ?)
|
||||||
|
{"bytea nullable", "bytea", true, "byte[]"},
|
||||||
|
{"bytea not nullable", "bytea", false, "byte[]"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "Guid"},
|
||||||
|
{"uuid not nullable", "uuid", false, "Guid?"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "string"},
|
||||||
|
|
||||||
|
// Unknown types (object is reference type)
|
||||||
|
{"unknown type nullable", "unknowntype", true, "object"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "object"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToCSharp(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedsTimeImport(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
goType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"time.Time type", "time.Time", true},
|
||||||
|
{"pointer to time.Time", "*time.Time", true},
|
||||||
|
{"int32 type", "int32", false},
|
||||||
|
{"string type", "string", false},
|
||||||
|
{"bool type", "bool", false},
|
||||||
|
{"[]byte type", "[]byte", false},
|
||||||
|
{"interface{}", "interface{}", false},
|
||||||
|
{"empty string", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NeedsTimeImport(tt.goType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoTypeMap(t *testing.T) {
|
||||||
|
// Test that the map contains expected entries
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int32",
|
||||||
|
"bigint": "int64",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "float64",
|
||||||
|
"bytea": "[]byte",
|
||||||
|
"timestamp": "time.Time",
|
||||||
|
"uuid": "string",
|
||||||
|
"json": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedGoType := range expectedMappings {
|
||||||
|
if goType, ok := GoTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("GoTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if goType != expectedGoType {
|
||||||
|
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(GoTypeMap) == 0 {
|
||||||
|
t.Error("GoTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeScriptTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "number",
|
||||||
|
"bigint": "number",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "boolean",
|
||||||
|
"double precision": "number",
|
||||||
|
"bytea": "Buffer",
|
||||||
|
"timestamp": "Date",
|
||||||
|
"uuid": "string",
|
||||||
|
"json": "any",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedTSType := range expectedMappings {
|
||||||
|
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if tsType != expectedTSType {
|
||||||
|
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(TypeScriptTypeMap) == 0 {
|
||||||
|
t.Error("TypeScriptTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPythonTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"bigint": "int",
|
||||||
|
"text": "str",
|
||||||
|
"boolean": "bool",
|
||||||
|
"real": "float",
|
||||||
|
"numeric": "Decimal",
|
||||||
|
"bytea": "bytes",
|
||||||
|
"date": "date",
|
||||||
|
"uuid": "UUID",
|
||||||
|
"json": "dict",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedPyType := range expectedMappings {
|
||||||
|
if pyType, ok := PythonTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if pyType != expectedPyType {
|
||||||
|
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(PythonTypeMap) == 0 {
|
||||||
|
t.Error("PythonTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSharpTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"bigint": "long",
|
||||||
|
"smallint": "short",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "double",
|
||||||
|
"decimal": "decimal",
|
||||||
|
"bytea": "byte[]",
|
||||||
|
"timestamp": "DateTime",
|
||||||
|
"uuid": "Guid",
|
||||||
|
"json": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedCSType := range expectedMappings {
|
||||||
|
if csType, ok := CSharpTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if csType != expectedCSType {
|
||||||
|
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(CSharpTypeMap) == 0 {
|
||||||
|
t.Error("CSharpTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToJava(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer nullable", "integer", true, "Integer"},
|
||||||
|
{"integer not nullable", "integer", false, "Integer"},
|
||||||
|
{"bigint nullable", "bigint", true, "Long"},
|
||||||
|
{"smallint nullable", "smallint", true, "Short"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "String"},
|
||||||
|
{"varchar nullable", "varchar", true, "String"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "Boolean"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "Float"},
|
||||||
|
{"double precision nullable", "double precision", true, "Double"},
|
||||||
|
{"numeric nullable", "numeric", true, "BigDecimal"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "Timestamp"},
|
||||||
|
{"date nullable", "date", true, "Date"},
|
||||||
|
{"time nullable", "time", true, "Time"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "byte[]"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "UUID"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "String"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "Object"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToJava(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToPhp(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "int"},
|
||||||
|
{"bigint nullable", "bigint", true, "int"},
|
||||||
|
{"smallint nullable", "smallint", true, "int"},
|
||||||
|
|
||||||
|
// Integer types (not nullable)
|
||||||
|
{"integer not nullable", "integer", false, "?int"},
|
||||||
|
{"bigint not nullable", "bigint", false, "?int"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"text not nullable", "text", false, "?string"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "?bool"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "float"},
|
||||||
|
{"double precision nullable", "double precision", true, "float"},
|
||||||
|
{"real not nullable", "real", false, "?float"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "\\DateTime"},
|
||||||
|
{"date nullable", "date", true, "\\DateTime"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "string"},
|
||||||
|
{"bytea not nullable", "bytea", false, "?string"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "array"},
|
||||||
|
{"json not nullable", "json", false, "?array"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "string"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "mixed"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "mixed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToPhp(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToRust(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "i32"},
|
||||||
|
{"bigint nullable", "bigint", true, "i64"},
|
||||||
|
{"smallint nullable", "smallint", true, "i16"},
|
||||||
|
|
||||||
|
// Integer types (not nullable)
|
||||||
|
{"integer not nullable", "integer", false, "Option<i32>"},
|
||||||
|
{"bigint not nullable", "bigint", false, "Option<i64>"},
|
||||||
|
{"smallint not nullable", "smallint", false, "Option<i16>"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "String"},
|
||||||
|
{"text not nullable", "text", false, "Option<String>"},
|
||||||
|
{"varchar nullable", "varchar", true, "String"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "Option<bool>"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "f32"},
|
||||||
|
{"double precision nullable", "double precision", true, "f64"},
|
||||||
|
{"real not nullable", "real", false, "Option<f32>"},
|
||||||
|
{"double precision not nullable", "double precision", false, "Option<f64>"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
|
||||||
|
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
|
||||||
|
{"date nullable", "date", true, "NaiveDate"},
|
||||||
|
{"time nullable", "time", true, "NaiveTime"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "Vec<u8>"},
|
||||||
|
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "serde_json::Value"},
|
||||||
|
{"json not nullable", "json", false, "Option<serde_json::Value>"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "String"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "String"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToRust(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJavaTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "Integer",
|
||||||
|
"bigint": "Long",
|
||||||
|
"smallint": "Short",
|
||||||
|
"text": "String",
|
||||||
|
"boolean": "Boolean",
|
||||||
|
"double precision": "Double",
|
||||||
|
"numeric": "BigDecimal",
|
||||||
|
"bytea": "byte[]",
|
||||||
|
"timestamp": "Timestamp",
|
||||||
|
"uuid": "UUID",
|
||||||
|
"date": "Date",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedJavaType := range expectedMappings {
|
||||||
|
if javaType, ok := JavaTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if javaType != expectedJavaType {
|
||||||
|
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(JavaTypeMap) == 0 {
|
||||||
|
t.Error("JavaTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPHPTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"bigint": "int",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "float",
|
||||||
|
"bytea": "string",
|
||||||
|
"timestamp": "\\DateTime",
|
||||||
|
"uuid": "string",
|
||||||
|
"json": "array",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedPHPType := range expectedMappings {
|
||||||
|
if phpType, ok := PHPTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if phpType != expectedPHPType {
|
||||||
|
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(PHPTypeMap) == 0 {
|
||||||
|
t.Error("PHPTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRustTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "i32",
|
||||||
|
"bigint": "i64",
|
||||||
|
"smallint": "i16",
|
||||||
|
"text": "String",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "f64",
|
||||||
|
"real": "f32",
|
||||||
|
"bytea": "Vec<u8>",
|
||||||
|
"timestamp": "NaiveDateTime",
|
||||||
|
"timestamptz": "DateTime<Utc>",
|
||||||
|
"date": "NaiveDate",
|
||||||
|
"json": "serde_json::Value",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedRustType := range expectedMappings {
|
||||||
|
if rustType, ok := RustTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("RustTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if rustType != expectedRustType {
|
||||||
|
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(RustTypeMap) == 0 {
|
||||||
|
t.Error("RustTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
558
pkg/diff/diff_test.go
Normal file
558
pkg/diff/diff_test.go
Normal file
@@ -0,0 +1,558 @@
|
|||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompareDatabases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source *models.Database
|
||||||
|
target *models.Database
|
||||||
|
want func(*DiffResult) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical databases",
|
||||||
|
source: &models.Database{
|
||||||
|
Name: "source",
|
||||||
|
Schemas: []*models.Schema{},
|
||||||
|
},
|
||||||
|
target: &models.Database{
|
||||||
|
Name: "target",
|
||||||
|
Schemas: []*models.Schema{},
|
||||||
|
},
|
||||||
|
want: func(r *DiffResult) bool {
|
||||||
|
return r.Source == "source" && r.Target == "target" &&
|
||||||
|
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different schemas",
|
||||||
|
source: &models.Database{
|
||||||
|
Name: "source",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
target: &models.Database{
|
||||||
|
Name: "target",
|
||||||
|
Schemas: []*models.Schema{},
|
||||||
|
},
|
||||||
|
want: func(r *DiffResult) bool {
|
||||||
|
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := CompareDatabases(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("CompareDatabases() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareColumns(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Column
|
||||||
|
target map[string]*models.Column
|
||||||
|
want func(*ColumnDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical columns",
|
||||||
|
source: map[string]*models.Column{},
|
||||||
|
target: map[string]*models.Column{},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing column",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra column",
|
||||||
|
source: map[string]*models.Column{},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified column type",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint"},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
|
||||||
|
d.Modified[0].Changes["type"] != nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified column nullable",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "text", NotNull: true},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "text", NotNull: false},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified column length",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "varchar", Length: 100},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "varchar", Length: 255},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareColumns(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareColumns() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareColumnDetails(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source *models.Column
|
||||||
|
target *models.Column
|
||||||
|
want int // number of changes
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical columns",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer"},
|
||||||
|
target: &models.Column{Name: "id", Type: "integer"},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "type change",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer"},
|
||||||
|
target: &models.Column{Name: "id", Type: "bigint"},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "length change",
|
||||||
|
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
|
||||||
|
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "precision change",
|
||||||
|
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
|
||||||
|
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scale change",
|
||||||
|
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
|
||||||
|
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null change",
|
||||||
|
source: &models.Column{Name: "name", Type: "text", NotNull: true},
|
||||||
|
target: &models.Column{Name: "name", Type: "text", NotNull: false},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto increment change",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
|
||||||
|
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "primary key change",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
|
||||||
|
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple changes",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
|
||||||
|
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
|
||||||
|
want: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareColumnDetails(tt.source, tt.target)
|
||||||
|
if len(got) != tt.want {
|
||||||
|
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareIndexes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Index
|
||||||
|
target map[string]*models.Index
|
||||||
|
want func(*IndexDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical indexes",
|
||||||
|
source: map[string]*models.Index{},
|
||||||
|
target: map[string]*models.Index{},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing index",
|
||||||
|
source: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Index{},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra index",
|
||||||
|
source: map[string]*models.Index{},
|
||||||
|
target: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||||
|
},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified index uniqueness",
|
||||||
|
source: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
|
||||||
|
},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareIndexes(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareIndexes() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareConstraints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Constraint
|
||||||
|
target map[string]*models.Constraint
|
||||||
|
want func(*ConstraintDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical constraints",
|
||||||
|
source: map[string]*models.Constraint{},
|
||||||
|
target: map[string]*models.Constraint{},
|
||||||
|
want: func(d *ConstraintDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing constraint",
|
||||||
|
source: map[string]*models.Constraint{
|
||||||
|
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Constraint{},
|
||||||
|
want: func(d *ConstraintDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra constraint",
|
||||||
|
source: map[string]*models.Constraint{},
|
||||||
|
target: map[string]*models.Constraint{
|
||||||
|
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||||
|
},
|
||||||
|
want: func(d *ConstraintDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareConstraints(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareConstraints() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareRelationships(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Relationship
|
||||||
|
target map[string]*models.Relationship
|
||||||
|
want func(*RelationshipDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical relationships",
|
||||||
|
source: map[string]*models.Relationship{},
|
||||||
|
target: map[string]*models.Relationship{},
|
||||||
|
want: func(d *RelationshipDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing relationship",
|
||||||
|
source: map[string]*models.Relationship{
|
||||||
|
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Relationship{},
|
||||||
|
want: func(d *RelationshipDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra relationship",
|
||||||
|
source: map[string]*models.Relationship{},
|
||||||
|
target: map[string]*models.Relationship{
|
||||||
|
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||||
|
},
|
||||||
|
want: func(d *RelationshipDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareRelationships(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareRelationships() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareTables(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source []*models.Table
|
||||||
|
target []*models.Table
|
||||||
|
want func(*TableDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical tables",
|
||||||
|
source: []*models.Table{},
|
||||||
|
target: []*models.Table{},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing table",
|
||||||
|
source: []*models.Table{
|
||||||
|
{Name: "users", Schema: "public"},
|
||||||
|
},
|
||||||
|
target: []*models.Table{},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra table",
|
||||||
|
source: []*models.Table{},
|
||||||
|
target: []*models.Table{
|
||||||
|
{Name: "users", Schema: "public"},
|
||||||
|
},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified table",
|
||||||
|
source: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
target: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareTables(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareTables() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareSchemas(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source []*models.Schema
|
||||||
|
target []*models.Schema
|
||||||
|
want func(*SchemaDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical schemas",
|
||||||
|
source: []*models.Schema{},
|
||||||
|
target: []*models.Schema{},
|
||||||
|
want: func(d *SchemaDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing schema",
|
||||||
|
source: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
target: []*models.Schema{},
|
||||||
|
want: func(d *SchemaDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra schema",
|
||||||
|
source: []*models.Schema{},
|
||||||
|
target: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
want: func(d *SchemaDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareSchemas(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareSchemas() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsEmpty(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
|
||||||
|
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
|
||||||
|
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
|
||||||
|
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
|
||||||
|
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
|
||||||
|
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
|
||||||
|
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
|
||||||
|
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
|
||||||
|
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isEmpty(tt.v)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeSummary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result *DiffResult
|
||||||
|
want func(*Summary) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty diff",
|
||||||
|
result: &DiffResult{
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(s *Summary) bool {
|
||||||
|
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "schemas with differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "schema1"}},
|
||||||
|
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(s *Summary) bool {
|
||||||
|
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ComputeSummary(tt.result)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("ComputeSummary() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
440
pkg/diff/formatters_test.go
Normal file
440
pkg/diff/formatters_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatDiff(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source_db",
|
||||||
|
Target: "target_db",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
format OutputFormat
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"summary format", FormatSummary, false},
|
||||||
|
{"json format", FormatJSON, false},
|
||||||
|
{"html format", FormatHTML, false},
|
||||||
|
{"invalid format", OutputFormat("invalid"), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := FormatDiff(result, tt.format, &buf)
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr && buf.Len() == 0 {
|
||||||
|
t.Error("FormatDiff() produced empty output")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result *DiffResult
|
||||||
|
wantStr []string // strings that should appear in output
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{"source", "target", "No differences found"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with schema differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "schema1"}},
|
||||||
|
Extra: []*models.Schema{{Name: "schema2"}},
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with table differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Missing: []*models.Table{{Name: "users"}},
|
||||||
|
Extra: []*models.Table{{Name: "posts"}},
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{Name: "comments", Schema: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(tt.result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
for _, want := range tt.wantStr {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatJSON(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "schema1"}},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatJSON(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatJSON() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if output is valid JSON
|
||||||
|
var decoded DiffResult
|
||||||
|
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
|
||||||
|
t.Errorf("formatJSON() produced invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check basic structure
|
||||||
|
if decoded.Source != "source" {
|
||||||
|
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
|
||||||
|
}
|
||||||
|
if decoded.Target != "target" {
|
||||||
|
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
|
||||||
|
}
|
||||||
|
if len(decoded.Schemas.Missing) != 1 {
|
||||||
|
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatHTML(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result *DiffResult
|
||||||
|
wantStr []string // HTML elements/content that should appear
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic HTML structure",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<title>Database Diff Report</title>",
|
||||||
|
"source",
|
||||||
|
"target",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with schema differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "missing_schema"}},
|
||||||
|
Extra: []*models.Schema{{Name: "extra_schema"}},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"missing_schema",
|
||||||
|
"extra_schema",
|
||||||
|
"MISSING",
|
||||||
|
"EXTRA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with table modifications",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: &ColumnDiff{
|
||||||
|
Missing: []*models.Column{{Name: "email", Type: "text"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{
|
||||||
|
"public",
|
||||||
|
"users",
|
||||||
|
"email",
|
||||||
|
"text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatHTML(tt.result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatHTML() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
for _, want := range tt.wantStr {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Errorf("formatHTML() output doesn't contain %q", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryWithColumns(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: &ColumnDiff{
|
||||||
|
Missing: []*models.Column{{Name: "email"}},
|
||||||
|
Extra: []*models.Column{{Name: "phone"}, {Name: "address"}},
|
||||||
|
Modified: []*ColumnChange{
|
||||||
|
{Name: "name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
wantStrings := []string{
|
||||||
|
"Columns:",
|
||||||
|
"Missing: 1",
|
||||||
|
"Extra: 2",
|
||||||
|
"Modified: 1",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range wantStrings {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryWithIndexes(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Indexes: &IndexDiff{
|
||||||
|
Missing: []*models.Index{{Name: "idx_email"}},
|
||||||
|
Extra: []*models.Index{{Name: "idx_phone"}},
|
||||||
|
Modified: []*IndexChange{{Name: "idx_name"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "Indexes:") {
|
||||||
|
t.Error("formatSummary() output doesn't contain Indexes section")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Missing: 1") {
|
||||||
|
t.Error("formatSummary() output doesn't contain correct missing count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryWithConstraints(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Constraints: &ConstraintDiff{
|
||||||
|
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
|
||||||
|
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "Constraints:") {
|
||||||
|
t.Error("formatSummary() output doesn't contain Constraints section")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatJSONIndentation(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "test"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatJSON(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatJSON() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that JSON is indented (has newlines and spaces)
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "\n") {
|
||||||
|
t.Error("formatJSON() should produce indented JSON with newlines")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, " ") {
|
||||||
|
t.Error("formatJSON() should produce indented JSON with spaces")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutputFormatConstants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
format OutputFormat
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"summary constant", FormatSummary, "summary"},
|
||||||
|
{"json constant", FormatJSON, "json"},
|
||||||
|
{"html constant", FormatHTML, "html"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if string(tt.format) != tt.want {
|
||||||
|
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
238
pkg/inspector/inspector_test.go
Normal file
238
pkg/inspector/inspector_test.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewInspector(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
|
||||||
|
if inspector == nil {
|
||||||
|
t.Fatal("NewInspector() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if inspector.db != db {
|
||||||
|
t.Error("NewInspector() database not set correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if inspector.config != config {
|
||||||
|
t.Error("NewInspector() config not set correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInspect(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
report, err := inspector.Inspect()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Inspect() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report == nil {
|
||||||
|
t.Fatal("Inspect() returned nil report")
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Database != db.Name {
|
||||||
|
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.TotalRules != len(config.Rules) {
|
||||||
|
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(report.Violations) == 0 {
|
||||||
|
t.Error("Inspect() returned no violations, expected some results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInspectWithDisabledRules(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
// Disable all rules
|
||||||
|
for name := range config.Rules {
|
||||||
|
rule := config.Rules[name]
|
||||||
|
rule.Enabled = "off"
|
||||||
|
config.Rules[name] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
report, err := inspector.Inspect()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.RulesChecked != 0 {
|
||||||
|
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.RulesSkipped != len(config.Rules) {
|
||||||
|
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInspectWithEnforcedRules(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
// Enable only one rule and enforce it
|
||||||
|
for name := range config.Rules {
|
||||||
|
rule := config.Rules[name]
|
||||||
|
rule.Enabled = "off"
|
||||||
|
config.Rules[name] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryKeyRule := config.Rules["primary_key_naming"]
|
||||||
|
primaryKeyRule.Enabled = "enforce"
|
||||||
|
primaryKeyRule.Pattern = "^id$"
|
||||||
|
config.Rules["primary_key_naming"] = primaryKeyRule
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
report, err := inspector.Inspect()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Inspect() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.RulesChecked != 1 {
|
||||||
|
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All results should be at error level for enforced rules
|
||||||
|
for _, violation := range report.Violations {
|
||||||
|
if violation.Level != "error" {
|
||||||
|
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSummary(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
|
||||||
|
results := []ValidationResult{
|
||||||
|
{RuleName: "rule1", Passed: true, Level: "error"},
|
||||||
|
{RuleName: "rule2", Passed: false, Level: "error"},
|
||||||
|
{RuleName: "rule3", Passed: false, Level: "warning"},
|
||||||
|
{RuleName: "rule4", Passed: true, Level: "warning"},
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := inspector.generateSummary(results)
|
||||||
|
|
||||||
|
if summary.PassedCount != 2 {
|
||||||
|
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if summary.ErrorCount != 1 {
|
||||||
|
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if summary.WarningCount != 1 {
|
||||||
|
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasErrors(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
report *InspectorReport
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with errors",
|
||||||
|
report: &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
ErrorCount: 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without errors",
|
||||||
|
report: &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
ErrorCount: 0,
|
||||||
|
WarningCount: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.report.HasErrors(); got != tt.want {
|
||||||
|
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValidator(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
functionName string
|
||||||
|
wantExists bool
|
||||||
|
}{
|
||||||
|
{"primary_key_naming", "primary_key_naming", true},
|
||||||
|
{"primary_key_datatype", "primary_key_datatype", true},
|
||||||
|
{"foreign_key_column_naming", "foreign_key_column_naming", true},
|
||||||
|
{"table_regexpr", "table_regexpr", true},
|
||||||
|
{"column_regexpr", "column_regexpr", true},
|
||||||
|
{"reserved_words", "reserved_words", true},
|
||||||
|
{"have_primary_key", "have_primary_key", true},
|
||||||
|
{"orphaned_foreign_key", "orphaned_foreign_key", true},
|
||||||
|
{"circular_dependency", "circular_dependency", true},
|
||||||
|
{"unknown_function", "unknown_function", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, exists := getValidator(tt.functionName)
|
||||||
|
if exists != tt.wantExists {
|
||||||
|
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateResult(t *testing.T) {
|
||||||
|
result := createResult(
|
||||||
|
"test_rule",
|
||||||
|
true,
|
||||||
|
"Test message",
|
||||||
|
"schema.table.column",
|
||||||
|
map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": 42,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.RuleName != "test_rule" {
|
||||||
|
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Passed {
|
||||||
|
t.Error("createResult() Passed = false, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Message != "Test message" {
|
||||||
|
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Location != "schema.table.column" {
|
||||||
|
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Context) != 2 {
|
||||||
|
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
|
||||||
|
}
|
||||||
|
}
|
||||||
366
pkg/inspector/report_test.go
Normal file
366
pkg/inspector/report_test.go
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestReport() *InspectorReport {
|
||||||
|
return &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
TotalRules: 10,
|
||||||
|
RulesChecked: 8,
|
||||||
|
RulesSkipped: 2,
|
||||||
|
ErrorCount: 3,
|
||||||
|
WarningCount: 5,
|
||||||
|
PassedCount: 12,
|
||||||
|
},
|
||||||
|
Violations: []ValidationResult{
|
||||||
|
{
|
||||||
|
RuleName: "primary_key_naming",
|
||||||
|
Level: "error",
|
||||||
|
Message: "Primary key should start with 'id_'",
|
||||||
|
Location: "public.users.user_id",
|
||||||
|
Passed: false,
|
||||||
|
Context: map[string]interface{}{
|
||||||
|
"schema": "public",
|
||||||
|
"table": "users",
|
||||||
|
"column": "user_id",
|
||||||
|
"pattern": "^id_",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
RuleName: "table_name_length",
|
||||||
|
Level: "warning",
|
||||||
|
Message: "Table name too long",
|
||||||
|
Location: "public.very_long_table_name_that_exceeds_limits",
|
||||||
|
Passed: false,
|
||||||
|
Context: map[string]interface{}{
|
||||||
|
"schema": "public",
|
||||||
|
"table": "very_long_table_name_that_exceeds_limits",
|
||||||
|
"length": 44,
|
||||||
|
"max_length": 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GeneratedAt: time.Now(),
|
||||||
|
Database: "testdb",
|
||||||
|
SourceFormat: "postgresql",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewMarkdownFormatter(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
if formatter == nil {
|
||||||
|
t.Fatal("NewMarkdownFormatter() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer is not a terminal, so colors should be disabled
|
||||||
|
if formatter.UseColors {
|
||||||
|
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewJSONFormatter(t *testing.T) {
|
||||||
|
formatter := NewJSONFormatter()
|
||||||
|
|
||||||
|
if formatter == nil {
|
||||||
|
t.Fatal("NewJSONFormatter() returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_Format(t *testing.T) {
|
||||||
|
report := createTestReport()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
output, err := formatter.Format(report)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that output contains expected sections
|
||||||
|
if !strings.Contains(output, "# RelSpec Inspector Report") {
|
||||||
|
t.Error("Markdown output missing header")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Database:") {
|
||||||
|
t.Error("Markdown output missing database field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "testdb") {
|
||||||
|
t.Error("Markdown output missing database name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Summary") {
|
||||||
|
t.Error("Markdown output missing summary section")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Rules Checked: 8") {
|
||||||
|
t.Error("Markdown output missing rules checked count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Errors: 3") {
|
||||||
|
t.Error("Markdown output missing error count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Warnings: 5") {
|
||||||
|
t.Error("Markdown output missing warning count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Violations") {
|
||||||
|
t.Error("Markdown output missing violations section")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "primary_key_naming") {
|
||||||
|
t.Error("Markdown output missing rule name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "public.users.user_id") {
|
||||||
|
t.Error("Markdown output missing location")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
|
||||||
|
report := &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
TotalRules: 10,
|
||||||
|
RulesChecked: 10,
|
||||||
|
RulesSkipped: 0,
|
||||||
|
ErrorCount: 0,
|
||||||
|
WarningCount: 0,
|
||||||
|
PassedCount: 50,
|
||||||
|
},
|
||||||
|
Violations: []ValidationResult{},
|
||||||
|
GeneratedAt: time.Now(),
|
||||||
|
Database: "testdb",
|
||||||
|
SourceFormat: "postgresql",
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
output, err := formatter.Format(report)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "No violations found") {
|
||||||
|
t.Error("Markdown output should indicate no violations")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONFormatter_Format(t *testing.T) {
|
||||||
|
report := createTestReport()
|
||||||
|
formatter := NewJSONFormatter()
|
||||||
|
|
||||||
|
output, err := formatter.Format(report)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid JSON
|
||||||
|
var decoded InspectorReport
|
||||||
|
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
|
||||||
|
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check key fields
|
||||||
|
if decoded.Database != "testdb" {
|
||||||
|
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded.Summary.ErrorCount != 3 {
|
||||||
|
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decoded.Violations) != 2 {
|
||||||
|
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
header := formatter.formatHeader("Test Header")
|
||||||
|
|
||||||
|
if !strings.Contains(header, "# Test Header") {
|
||||||
|
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatBold(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
useColors bool
|
||||||
|
text string
|
||||||
|
wantContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without colors",
|
||||||
|
useColors: false,
|
||||||
|
text: "Bold Text",
|
||||||
|
wantContains: "**Bold Text**",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with colors",
|
||||||
|
useColors: true,
|
||||||
|
text: "Bold Text",
|
||||||
|
wantContains: "Bold Text",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||||
|
result := formatter.formatBold(tt.text)
|
||||||
|
|
||||||
|
if !strings.Contains(result, tt.wantContains) {
|
||||||
|
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_Colorize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
useColors bool
|
||||||
|
text string
|
||||||
|
color string
|
||||||
|
wantColor bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without colors",
|
||||||
|
useColors: false,
|
||||||
|
text: "Test",
|
||||||
|
color: colorRed,
|
||||||
|
wantColor: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with colors",
|
||||||
|
useColors: true,
|
||||||
|
text: "Test",
|
||||||
|
color: colorRed,
|
||||||
|
wantColor: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||||
|
result := formatter.colorize(tt.text, tt.color)
|
||||||
|
|
||||||
|
hasColor := strings.Contains(result, tt.color)
|
||||||
|
if hasColor != tt.wantColor {
|
||||||
|
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, tt.text) {
|
||||||
|
t.Errorf("colorize() doesn't contain original text %q", tt.text)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatContext(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: false}
|
||||||
|
|
||||||
|
context := map[string]interface{}{
|
||||||
|
"schema": "public",
|
||||||
|
"table": "users",
|
||||||
|
"column": "id",
|
||||||
|
"pattern": "^id_",
|
||||||
|
"max_length": 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := formatter.formatContext(context)
|
||||||
|
|
||||||
|
// Should not include schema, table, column (they're in location)
|
||||||
|
if strings.Contains(result, "schema") {
|
||||||
|
t.Error("formatContext() should skip schema field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(result, "table=") {
|
||||||
|
t.Error("formatContext() should skip table field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(result, "column=") {
|
||||||
|
t.Error("formatContext() should skip column field")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should include other fields
|
||||||
|
if !strings.Contains(result, "pattern") {
|
||||||
|
t.Error("formatContext() should include pattern field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "max_length") {
|
||||||
|
t.Error("formatContext() should include max_length field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: false}
|
||||||
|
|
||||||
|
violation := ValidationResult{
|
||||||
|
RuleName: "test_rule",
|
||||||
|
Level: "error",
|
||||||
|
Message: "Test violation message",
|
||||||
|
Location: "public.users.id",
|
||||||
|
Passed: false,
|
||||||
|
Context: map[string]interface{}{
|
||||||
|
"pattern": "^id_",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := formatter.formatViolation(violation, colorRed)
|
||||||
|
|
||||||
|
if !strings.Contains(result, "test_rule") {
|
||||||
|
t.Error("formatViolation() should include rule name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "Test violation message") {
|
||||||
|
t.Error("formatViolation() should include message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "public.users.id") {
|
||||||
|
t.Error("formatViolation() should include location")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "Location:") {
|
||||||
|
t.Error("formatViolation() should include Location label")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "Message:") {
|
||||||
|
t.Error("formatViolation() should include Message label")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReportFormatConstants(t *testing.T) {
|
||||||
|
// Test that color constants are defined
|
||||||
|
if colorReset == "" {
|
||||||
|
t.Error("colorReset is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorRed == "" {
|
||||||
|
t.Error("colorRed is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorYellow == "" {
|
||||||
|
t.Error("colorYellow is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorGreen == "" {
|
||||||
|
t.Error("colorGreen is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorBold == "" {
|
||||||
|
t.Error("colorBold is not defined")
|
||||||
|
}
|
||||||
|
}
|
||||||
249
pkg/inspector/rules_test.go
Normal file
249
pkg/inspector/rules_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDefaultConfig(t *testing.T) {
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
t.Fatal("GetDefaultConfig() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Version != "1.0" {
|
||||||
|
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Rules) == 0 {
|
||||||
|
t.Error("GetDefaultConfig() returned no rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all expected rules are present
|
||||||
|
expectedRules := []string{
|
||||||
|
"primary_key_naming",
|
||||||
|
"primary_key_datatype",
|
||||||
|
"primary_key_auto_increment",
|
||||||
|
"foreign_key_column_naming",
|
||||||
|
"foreign_key_constraint_naming",
|
||||||
|
"foreign_key_index",
|
||||||
|
"table_naming_case",
|
||||||
|
"column_naming_case",
|
||||||
|
"table_name_length",
|
||||||
|
"column_name_length",
|
||||||
|
"reserved_keywords",
|
||||||
|
"missing_primary_key",
|
||||||
|
"orphaned_foreign_key",
|
||||||
|
"circular_dependency",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ruleName := range expectedRules {
|
||||||
|
if _, exists := config.Rules[ruleName]; !exists {
|
||||||
|
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfig_NonExistentFile(t *testing.T) {
|
||||||
|
// Try to load a non-existent file
|
||||||
|
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return default config
|
||||||
|
if config == nil {
|
||||||
|
t.Fatal("LoadConfig() returned nil config for non-existent file")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Rules) == 0 {
|
||||||
|
t.Error("LoadConfig() returned config with no rules")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfig_ValidFile(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "test-config.yaml")
|
||||||
|
|
||||||
|
configContent := `version: "1.0"
|
||||||
|
rules:
|
||||||
|
primary_key_naming:
|
||||||
|
enabled: "enforce"
|
||||||
|
function: "primary_key_naming"
|
||||||
|
pattern: "^pk_"
|
||||||
|
message: "Primary keys must start with pk_"
|
||||||
|
table_name_length:
|
||||||
|
enabled: "warn"
|
||||||
|
function: "table_name_length"
|
||||||
|
max_length: 50
|
||||||
|
message: "Table name too long"
|
||||||
|
`
|
||||||
|
|
||||||
|
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfig() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Version != "1.0" {
|
||||||
|
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Rules) != 2 {
|
||||||
|
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check primary_key_naming rule
|
||||||
|
pkRule, exists := config.Rules["primary_key_naming"]
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("LoadConfig() missing primary_key_naming rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkRule.Enabled != "enforce" {
|
||||||
|
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkRule.Pattern != "^pk_" {
|
||||||
|
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check table_name_length rule
|
||||||
|
lengthRule, exists := config.Rules["table_name_length"]
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("LoadConfig() missing table_name_length rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
if lengthRule.MaxLength != 50 {
|
||||||
|
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||||
|
// Create a temporary invalid config file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
|
||||||
|
|
||||||
|
invalidContent := `invalid: yaml: content: {[}]`
|
||||||
|
|
||||||
|
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = LoadConfig(configPath)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("LoadConfig() with invalid YAML did not return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuleIsEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enforce is enabled",
|
||||||
|
rule: Rule{Enabled: "enforce"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warn is enabled",
|
||||||
|
rule: Rule{Enabled: "warn"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "off is not enabled",
|
||||||
|
rule: Rule{Enabled: "off"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty is not enabled",
|
||||||
|
rule: Rule{Enabled: ""},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.rule.IsEnabled(); got != tt.want {
|
||||||
|
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuleIsEnforced(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enforce is enforced",
|
||||||
|
rule: Rule{Enabled: "enforce"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warn is not enforced",
|
||||||
|
rule: Rule{Enabled: "warn"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "off is not enforced",
|
||||||
|
rule: Rule{Enabled: "off"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.rule.IsEnforced(); got != tt.want {
|
||||||
|
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultConfigRuleSettings(t *testing.T) {
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
// Test specific rule settings
|
||||||
|
pkNamingRule := config.Rules["primary_key_naming"]
|
||||||
|
if pkNamingRule.Function != "primary_key_naming" {
|
||||||
|
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkNamingRule.Pattern != "^id_" {
|
||||||
|
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test datatype rule
|
||||||
|
pkDatatypeRule := config.Rules["primary_key_datatype"]
|
||||||
|
if len(pkDatatypeRule.AllowedTypes) == 0 {
|
||||||
|
t.Error("primary_key_datatype has no allowed types")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test length rule
|
||||||
|
tableLengthRule := config.Rules["table_name_length"]
|
||||||
|
if tableLengthRule.MaxLength != 64 {
|
||||||
|
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test reserved keywords rule
|
||||||
|
reservedRule := config.Rules["reserved_keywords"]
|
||||||
|
if !reservedRule.CheckTables {
|
||||||
|
t.Error("reserved_keywords.CheckTables should be true")
|
||||||
|
}
|
||||||
|
if !reservedRule.CheckColumns {
|
||||||
|
t.Error("reserved_keywords.CheckColumns should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
837
pkg/inspector/validators_test.go
Normal file
837
pkg/inspector/validators_test.go
Normal file
@@ -0,0 +1,837 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to create test database
|
||||||
|
func createTestDatabase() *models.Database {
|
||||||
|
return &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigserial",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
"username": {
|
||||||
|
Name: "username",
|
||||||
|
Type: "varchar(50)",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
},
|
||||||
|
"rid_organization": {
|
||||||
|
Name: "rid_organization",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_users_organization": {
|
||||||
|
Name: "fk_users_organization",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_organization"},
|
||||||
|
ReferencedTable: "organizations",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
ReferencedColumns: []string{"id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_rid_organization": {
|
||||||
|
Name: "idx_rid_organization",
|
||||||
|
Columns: []string{"rid_organization"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "organizations",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigserial",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
Name: "name",
|
||||||
|
Type: "varchar(100)",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePrimaryKeyNaming(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching pattern id",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^id$",
|
||||||
|
Message: "Primary key should be 'id'",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching pattern id_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^id_",
|
||||||
|
Message: "Primary key should start with 'id_'",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePrimaryKeyDatatype(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allowed type bigserial",
|
||||||
|
rule: Rule{
|
||||||
|
AllowedTypes: []string{"bigserial", "bigint", "int"},
|
||||||
|
Message: "Primary key should use integer types",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disallowed type",
|
||||||
|
rule: Rule{
|
||||||
|
AllowedTypes: []string{"uuid"},
|
||||||
|
Message: "Primary key should use UUID",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "require auto increment",
|
||||||
|
rule: Rule{
|
||||||
|
RequireAutoIncrement: true,
|
||||||
|
Message: "Primary key should have auto-increment",
|
||||||
|
},
|
||||||
|
wantLen: 0, // No violations - all PKs have auto-increment
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disallow auto increment",
|
||||||
|
rule: Rule{
|
||||||
|
RequireAutoIncrement: false,
|
||||||
|
Message: "Primary key should not have auto-increment",
|
||||||
|
},
|
||||||
|
wantLen: 2, // 2 violations - both PKs have auto-increment
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateForeignKeyColumnNaming(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching pattern rid_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^rid_",
|
||||||
|
Message: "Foreign key columns should start with 'rid_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching pattern fk_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^fk_",
|
||||||
|
Message: "Foreign key columns should start with 'fk_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching pattern fk_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^fk_",
|
||||||
|
Message: "Foreign key constraints should start with 'fk_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching pattern FK_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^FK_",
|
||||||
|
Message: "Foreign key constraints should start with 'FK_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateForeignKeyIndex(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "require index with index present",
|
||||||
|
rule: Rule{
|
||||||
|
RequireIndex: true,
|
||||||
|
Message: "Foreign key columns should have indexes",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no requirement",
|
||||||
|
rule: Rule{
|
||||||
|
RequireIndex: false,
|
||||||
|
Message: "Foreign key index check disabled",
|
||||||
|
},
|
||||||
|
wantLen: 0,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTableNamingCase(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowercase snake_case pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[a-z][a-z0-9_]*$",
|
||||||
|
Case: "lowercase",
|
||||||
|
Message: "Table names should be lowercase snake_case",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[A-Z][A-Z0-9_]*$",
|
||||||
|
Case: "uppercase",
|
||||||
|
Message: "Table names should be uppercase",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateTableNamingCase(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumnNamingCase(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowercase snake_case pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[a-z][a-z0-9_]*$",
|
||||||
|
Case: "lowercase",
|
||||||
|
Message: "Column names should be lowercase snake_case",
|
||||||
|
},
|
||||||
|
wantLen: 5, // 5 total columns across both tables
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "camelCase pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[a-z][a-zA-Z0-9]*$",
|
||||||
|
Case: "camelCase",
|
||||||
|
Message: "Column names should be camelCase",
|
||||||
|
},
|
||||||
|
wantLen: 5,
|
||||||
|
wantPass: false, // rid_organization has underscore
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateColumnNamingCase(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTableNameLength(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "max length 64",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 64,
|
||||||
|
Message: "Table name too long",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max length 5",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 5,
|
||||||
|
Message: "Table name too long",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateTableNameLength(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumnNameLength(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "max length 64",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 64,
|
||||||
|
Message: "Column name too long",
|
||||||
|
},
|
||||||
|
wantLen: 5,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max length 5",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 5,
|
||||||
|
Message: "Column name too long",
|
||||||
|
},
|
||||||
|
wantLen: 5,
|
||||||
|
wantPass: false, // Some columns exceed 5 chars
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateColumnNameLength(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateReservedKeywords(t *testing.T) {
|
||||||
|
// Create a database with reserved keywords
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "user", // "user" is a reserved keyword
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
"select": { // "select" is a reserved keyword
|
||||||
|
Name: "select",
|
||||||
|
Type: "varchar(50)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
checkPasses bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "check tables only",
|
||||||
|
rule: Rule{
|
||||||
|
CheckTables: true,
|
||||||
|
CheckColumns: false,
|
||||||
|
Message: "Reserved keyword used",
|
||||||
|
},
|
||||||
|
wantLen: 1, // "user" table
|
||||||
|
checkPasses: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check columns only",
|
||||||
|
rule: Rule{
|
||||||
|
CheckTables: false,
|
||||||
|
CheckColumns: true,
|
||||||
|
Message: "Reserved keyword used",
|
||||||
|
},
|
||||||
|
wantLen: 2, // "id", "select" columns (id passes, select fails)
|
||||||
|
checkPasses: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check both",
|
||||||
|
rule: Rule{
|
||||||
|
CheckTables: true,
|
||||||
|
CheckColumns: true,
|
||||||
|
Message: "Reserved keyword used",
|
||||||
|
},
|
||||||
|
wantLen: 3, // "user" table + "id", "select" columns
|
||||||
|
checkPasses: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateReservedKeywords(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMissingPrimaryKey(t *testing.T) {
|
||||||
|
// Create database with and without primary keys
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "with_pk",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "without_pk",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"name": {
|
||||||
|
Name: "name",
|
||||||
|
Type: "varchar(50)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Message: "Table missing primary key",
|
||||||
|
}
|
||||||
|
|
||||||
|
results := validateMissingPrimaryKey(db, rule, "test_rule")
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// First result should pass (with_pk has PK)
|
||||||
|
if results[0].Passed != true {
|
||||||
|
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second result should fail (without_pk missing PK)
|
||||||
|
if results[1].Passed != false {
|
||||||
|
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOrphanedForeignKey(t *testing.T) {
|
||||||
|
// Create database with orphaned FK
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_nonexistent": {
|
||||||
|
Name: "fk_nonexistent",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_organization"},
|
||||||
|
ReferencedTable: "nonexistent_table",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Message: "Foreign key references non-existent table",
|
||||||
|
}
|
||||||
|
|
||||||
|
results := validateOrphanedForeignKey(db, rule, "test_rule")
|
||||||
|
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].Passed != false {
|
||||||
|
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateCircularDependency(t *testing.T) {
|
||||||
|
// Create database with circular dependency
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "table_a",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_to_b": {
|
||||||
|
Name: "fk_to_b",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
ReferencedTable: "table_b",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "table_b",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_to_a": {
|
||||||
|
Name: "fk_to_a",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
ReferencedTable: "table_a",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Message: "Circular dependency detected",
|
||||||
|
}
|
||||||
|
|
||||||
|
results := validateCircularDependency(db, rule, "test_rule")
|
||||||
|
|
||||||
|
// Should detect circular dependency in both tables
|
||||||
|
if len(results) == 0 {
|
||||||
|
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Passed {
|
||||||
|
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeDataType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"varchar(50)", "varchar"},
|
||||||
|
{"decimal(10,2)", "decimal"},
|
||||||
|
{"int", "int"},
|
||||||
|
{"BIGINT", "bigint"},
|
||||||
|
{"VARCHAR(255)", "varchar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
result := normalizeDataType(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
slice []string
|
||||||
|
value string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
|
||||||
|
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
|
||||||
|
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
|
||||||
|
{"empty slice", []string{}, "foo", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := contains(tt.slice, tt.value)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasCycle(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
graph map[string][]string
|
||||||
|
node string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple cycle",
|
||||||
|
graph: map[string][]string{
|
||||||
|
"A": {"B"},
|
||||||
|
"B": {"C"},
|
||||||
|
"C": {"A"},
|
||||||
|
},
|
||||||
|
node: "A",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no cycle",
|
||||||
|
graph: map[string][]string{
|
||||||
|
"A": {"B"},
|
||||||
|
"B": {"C"},
|
||||||
|
"C": {},
|
||||||
|
},
|
||||||
|
node: "A",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "self cycle",
|
||||||
|
graph: map[string][]string{
|
||||||
|
"A": {"A"},
|
||||||
|
},
|
||||||
|
node: "A",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
visited := make(map[string]bool)
|
||||||
|
recStack := make(map[string]bool)
|
||||||
|
result := hasCycle(tt.node, tt.graph, visited, recStack)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatLocation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
schema string
|
||||||
|
table string
|
||||||
|
column string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"public", "users", "id", "public.users.id"},
|
||||||
|
{"public", "users", "", "public.users"},
|
||||||
|
{"public", "", "", "public"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.expected, func(t *testing.T) {
|
||||||
|
result := formatLocation(tt.schema, tt.table, tt.column)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
|
||||||
|
tt.schema, tt.table, tt.column, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,14 +12,16 @@ import (
|
|||||||
|
|
||||||
// MergeResult represents the result of a merge operation
|
// MergeResult represents the result of a merge operation
|
||||||
type MergeResult struct {
|
type MergeResult struct {
|
||||||
SchemasAdded int
|
SchemasAdded int
|
||||||
TablesAdded int
|
TablesAdded int
|
||||||
ColumnsAdded int
|
ColumnsAdded int
|
||||||
RelationsAdded int
|
ConstraintsAdded int
|
||||||
DomainsAdded int
|
IndexesAdded int
|
||||||
EnumsAdded int
|
RelationsAdded int
|
||||||
ViewsAdded int
|
DomainsAdded int
|
||||||
SequencesAdded int
|
EnumsAdded int
|
||||||
|
ViewsAdded int
|
||||||
|
SequencesAdded int
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeOptions contains options for merge operations
|
// MergeOptions contains options for merge operations
|
||||||
@@ -120,8 +122,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tgtTable, exists := existingTables[tableName]; exists {
|
if tgtTable, exists := existingTables[tableName]; exists {
|
||||||
// Table exists, merge its columns
|
// Table exists, merge its columns, constraints, and indexes
|
||||||
r.mergeColumns(tgtTable, srcTable)
|
r.mergeColumns(tgtTable, srcTable)
|
||||||
|
r.mergeConstraints(tgtTable, srcTable)
|
||||||
|
r.mergeIndexes(tgtTable, srcTable)
|
||||||
} else {
|
} else {
|
||||||
// Table doesn't exist, add it
|
// Table doesn't exist, add it
|
||||||
newTable := cloneTable(srcTable)
|
newTable := cloneTable(srcTable)
|
||||||
@@ -151,6 +155,52 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
|
||||||
|
// Initialize constraints map if nil
|
||||||
|
if table.Constraints == nil {
|
||||||
|
table.Constraints = make(map[string]*models.Constraint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create map of existing constraints
|
||||||
|
existingConstraints := make(map[string]*models.Constraint)
|
||||||
|
for constName := range table.Constraints {
|
||||||
|
existingConstraints[constName] = table.Constraints[constName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge constraints
|
||||||
|
for constName, srcConst := range srcTable.Constraints {
|
||||||
|
if _, exists := existingConstraints[constName]; !exists {
|
||||||
|
// Constraint doesn't exist, add it
|
||||||
|
newConst := cloneConstraint(srcConst)
|
||||||
|
table.Constraints[constName] = newConst
|
||||||
|
r.ConstraintsAdded++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
|
||||||
|
// Initialize indexes map if nil
|
||||||
|
if table.Indexes == nil {
|
||||||
|
table.Indexes = make(map[string]*models.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create map of existing indexes
|
||||||
|
existingIndexes := make(map[string]*models.Index)
|
||||||
|
for idxName := range table.Indexes {
|
||||||
|
existingIndexes[idxName] = table.Indexes[idxName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge indexes
|
||||||
|
for idxName, srcIdx := range srcTable.Indexes {
|
||||||
|
if _, exists := existingIndexes[idxName]; !exists {
|
||||||
|
// Index doesn't exist, add it
|
||||||
|
newIdx := cloneIndex(srcIdx)
|
||||||
|
table.Indexes[idxName] = newIdx
|
||||||
|
r.IndexesAdded++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
||||||
// Create map of existing views
|
// Create map of existing views
|
||||||
existingViews := make(map[string]*models.View)
|
existingViews := make(map[string]*models.View)
|
||||||
@@ -552,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
||||||
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
||||||
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
||||||
|
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
|
||||||
|
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
|
||||||
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
||||||
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
||||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||||
@@ -560,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||||
|
result.ConstraintsAdded + result.IndexesAdded +
|
||||||
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
||||||
result.RelationsAdded + result.DomainsAdded
|
result.RelationsAdded + result.DomainsAdded
|
||||||
|
|
||||||
|
|||||||
617
pkg/merge/merge_test.go
Normal file
617
pkg/merge/merge_test.go
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
package merge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeDatabases_NilInputs(t *testing.T) {
|
||||||
|
result := MergeDatabases(nil, nil, nil)
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("Expected non-nil result")
|
||||||
|
}
|
||||||
|
if result.SchemasAdded != 0 {
|
||||||
|
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeDatabases_NewSchema(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "auth"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.SchemasAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas) != 2 {
|
||||||
|
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeDatabases_ExistingSchema(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.SchemasAdded != 0 {
|
||||||
|
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas) != 1 {
|
||||||
|
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeTables_NewTable(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "posts",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.TablesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables) != 2 {
|
||||||
|
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeColumns_NewColumn(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "int"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "varchar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ColumnsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: map[string]*models.Constraint{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"ukey_users_email": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"email"},
|
||||||
|
Name: "ukey_users_email",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ConstraintsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: nil, // Nil map
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"ukey_users_email": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"email"},
|
||||||
|
Name: "ukey_users_email",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ConstraintsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||||
|
}
|
||||||
|
if target.Schemas[0].Tables[0].Constraints == nil {
|
||||||
|
t.Error("Expected constraints map to be initialized")
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeIndexes_NewIndex(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: map[string]*models.Index{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_users_email": {
|
||||||
|
Name: "idx_users_email",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.IndexesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||||
|
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: nil, // Nil map
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_users_email": {
|
||||||
|
Name: "idx_users_email",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.IndexesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||||
|
}
|
||||||
|
if target.Schemas[0].Tables[0].Indexes == nil {
|
||||||
|
t.Error("Expected indexes map to be initialized")
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||||
|
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeOptions_SkipTableNames(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "migrations",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &MergeOptions{
|
||||||
|
SkipTableNames: map[string]bool{
|
||||||
|
"migrations": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, opts)
|
||||||
|
if result.TablesAdded != 0 {
|
||||||
|
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables) != 1 {
|
||||||
|
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeViews_NewView(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Views: []*models.View{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Views: []*models.View{
|
||||||
|
{
|
||||||
|
Name: "user_summary",
|
||||||
|
Schema: "public",
|
||||||
|
Definition: "SELECT * FROM users",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ViewsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Views) != 1 {
|
||||||
|
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeEnums_NewEnum(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Enums: []*models.Enum{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Enums: []*models.Enum{
|
||||||
|
{
|
||||||
|
Name: "user_role",
|
||||||
|
Schema: "public",
|
||||||
|
Values: []string{"admin", "user"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.EnumsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Enums) != 1 {
|
||||||
|
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeDomains_NewDomain(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Domains: []*models.Domain{},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Domains: []*models.Domain{
|
||||||
|
{
|
||||||
|
Name: "auth",
|
||||||
|
Description: "Authentication domain",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.DomainsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Domains) != 1 {
|
||||||
|
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeRelations_NewRelation(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Relations: []*models.Relationship{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Relations: []*models.Relationship{
|
||||||
|
{
|
||||||
|
Name: "fk_posts_user",
|
||||||
|
Type: models.OneToMany,
|
||||||
|
FromTable: "posts",
|
||||||
|
FromColumns: []string{"user_id"},
|
||||||
|
ToTable: "users",
|
||||||
|
ToColumns: []string{"id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.RelationsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Relations) != 1 {
|
||||||
|
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMergeSummary(t *testing.T) {
|
||||||
|
result := &MergeResult{
|
||||||
|
SchemasAdded: 1,
|
||||||
|
TablesAdded: 2,
|
||||||
|
ColumnsAdded: 5,
|
||||||
|
ConstraintsAdded: 3,
|
||||||
|
IndexesAdded: 2,
|
||||||
|
ViewsAdded: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := GetMergeSummary(result)
|
||||||
|
if summary == "" {
|
||||||
|
t.Error("Expected non-empty summary")
|
||||||
|
}
|
||||||
|
if len(summary) < 50 {
|
||||||
|
t.Errorf("Summary seems too short: %s", summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMergeSummary_Nil(t *testing.T) {
|
||||||
|
summary := GetMergeSummary(nil)
|
||||||
|
if summary == "" {
|
||||||
|
t.Error("Expected non-empty summary for nil result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComplexMerge(t *testing.T) {
|
||||||
|
// Target with existing structure
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "int"},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{},
|
||||||
|
Indexes: map[string]*models.Index{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source with new columns, constraints, and indexes
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "varchar"},
|
||||||
|
"guid": {Name: "guid", Type: "uuid"},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"ukey_users_email": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"email"},
|
||||||
|
Name: "ukey_users_email",
|
||||||
|
},
|
||||||
|
"ukey_users_guid": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"guid"},
|
||||||
|
Name: "ukey_users_guid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_users_email": {
|
||||||
|
Name: "idx_users_email",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
|
||||||
|
// Verify counts
|
||||||
|
if result.ColumnsAdded != 2 {
|
||||||
|
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
|
||||||
|
}
|
||||||
|
if result.ConstraintsAdded != 2 {
|
||||||
|
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
|
||||||
|
}
|
||||||
|
if result.IndexesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify target has merged data
|
||||||
|
table := target.Schemas[0].Tables[0]
|
||||||
|
if len(table.Columns) != 3 {
|
||||||
|
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
|
||||||
|
}
|
||||||
|
if len(table.Constraints) != 2 {
|
||||||
|
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
|
||||||
|
}
|
||||||
|
if len(table.Indexes) != 1 {
|
||||||
|
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify specific constraint
|
||||||
|
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
|
||||||
|
t.Error("Expected ukey_users_guid constraint to exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,31 +4,31 @@ import "strings"
|
|||||||
|
|
||||||
var GoToStdTypes = map[string]string{
|
var GoToStdTypes = map[string]string{
|
||||||
"bool": "boolean",
|
"bool": "boolean",
|
||||||
"int64": "integer",
|
"int64": "bigint",
|
||||||
"int": "integer",
|
"int": "integer",
|
||||||
"int8": "integer",
|
"int8": "smallint",
|
||||||
"int16": "integer",
|
"int16": "smallint",
|
||||||
"int32": "integer",
|
"int32": "integer",
|
||||||
"uint": "integer",
|
"uint": "integer",
|
||||||
"uint8": "integer",
|
"uint8": "smallint",
|
||||||
"uint16": "integer",
|
"uint16": "smallint",
|
||||||
"uint32": "integer",
|
"uint32": "integer",
|
||||||
"uint64": "integer",
|
"uint64": "bigint",
|
||||||
"uintptr": "integer",
|
"uintptr": "bigint",
|
||||||
"znullint64": "integer",
|
"znullint64": "bigint",
|
||||||
"znullint32": "integer",
|
"znullint32": "integer",
|
||||||
"znullbyte": "integer",
|
"znullbyte": "smallint",
|
||||||
"float64": "double",
|
"float64": "double",
|
||||||
"float32": "double",
|
"float32": "double",
|
||||||
"complex64": "double",
|
"complex64": "double",
|
||||||
"complex128": "double",
|
"complex128": "double",
|
||||||
"customfloat64": "double",
|
"customfloat64": "double",
|
||||||
"string": "string",
|
"string": "text",
|
||||||
"Pointer": "integer",
|
"Pointer": "bigint",
|
||||||
"[]byte": "blob",
|
"[]byte": "blob",
|
||||||
"customdate": "string",
|
"customdate": "date",
|
||||||
"customtime": "string",
|
"customtime": "time",
|
||||||
"customtimestamp": "string",
|
"customtimestamp": "timestamp",
|
||||||
"sqlfloat64": "double",
|
"sqlfloat64": "double",
|
||||||
"sqlfloat16": "double",
|
"sqlfloat16": "double",
|
||||||
"sqluuid": "uuid",
|
"sqluuid": "uuid",
|
||||||
@@ -36,9 +36,9 @@ var GoToStdTypes = map[string]string{
|
|||||||
"sqljson": "json",
|
"sqljson": "json",
|
||||||
"sqlint64": "bigint",
|
"sqlint64": "bigint",
|
||||||
"sqlint32": "integer",
|
"sqlint32": "integer",
|
||||||
"sqlint16": "integer",
|
"sqlint16": "smallint",
|
||||||
"sqlbool": "boolean",
|
"sqlbool": "boolean",
|
||||||
"sqlstring": "string",
|
"sqlstring": "text",
|
||||||
"nullablejsonb": "jsonb",
|
"nullablejsonb": "jsonb",
|
||||||
"nullablejson": "json",
|
"nullablejson": "json",
|
||||||
"nullableuuid": "uuid",
|
"nullableuuid": "uuid",
|
||||||
@@ -67,7 +67,7 @@ var GoToPGSQLTypes = map[string]string{
|
|||||||
"float32": "real",
|
"float32": "real",
|
||||||
"complex64": "double precision",
|
"complex64": "double precision",
|
||||||
"complex128": "double precision",
|
"complex128": "double precision",
|
||||||
"customfloat64": "double precisio",
|
"customfloat64": "double precision",
|
||||||
"string": "text",
|
"string": "text",
|
||||||
"Pointer": "bigint",
|
"Pointer": "bigint",
|
||||||
"[]byte": "bytea",
|
"[]byte": "bytea",
|
||||||
@@ -81,9 +81,9 @@ var GoToPGSQLTypes = map[string]string{
|
|||||||
"sqljson": "json",
|
"sqljson": "json",
|
||||||
"sqlint64": "bigint",
|
"sqlint64": "bigint",
|
||||||
"sqlint32": "integer",
|
"sqlint32": "integer",
|
||||||
"sqlint16": "integer",
|
"sqlint16": "smallint",
|
||||||
"sqlbool": "boolean",
|
"sqlbool": "boolean",
|
||||||
"sqlstring": "string",
|
"sqlstring": "text",
|
||||||
"nullablejsonb": "jsonb",
|
"nullablejsonb": "jsonb",
|
||||||
"nullablejson": "json",
|
"nullablejson": "json",
|
||||||
"nullableuuid": "uuid",
|
"nullableuuid": "uuid",
|
||||||
|
|||||||
339
pkg/pgsql/datatypes_test.go
Normal file
339
pkg/pgsql/datatypes_test.go
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
package pgsql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqltype string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// PostgreSQL types
|
||||||
|
{"Valid PGSQL bigint", "bigint", true},
|
||||||
|
{"Valid PGSQL integer", "integer", true},
|
||||||
|
{"Valid PGSQL text", "text", true},
|
||||||
|
{"Valid PGSQL boolean", "boolean", true},
|
||||||
|
{"Valid PGSQL double precision", "double precision", true},
|
||||||
|
{"Valid PGSQL bytea", "bytea", true},
|
||||||
|
{"Valid PGSQL uuid", "uuid", true},
|
||||||
|
{"Valid PGSQL jsonb", "jsonb", true},
|
||||||
|
{"Valid PGSQL json", "json", true},
|
||||||
|
{"Valid PGSQL timestamp", "timestamp", true},
|
||||||
|
{"Valid PGSQL date", "date", true},
|
||||||
|
{"Valid PGSQL time", "time", true},
|
||||||
|
{"Valid PGSQL citext", "citext", true},
|
||||||
|
|
||||||
|
// Standard types
|
||||||
|
{"Valid std double", "double", true},
|
||||||
|
{"Valid std blob", "blob", true},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BIGINT", "BIGINT", true},
|
||||||
|
{"Case insensitive TeXt", "TeXt", true},
|
||||||
|
{"Case insensitive BoOlEaN", "BoOlEaN", true},
|
||||||
|
|
||||||
|
// Invalid types
|
||||||
|
{"Invalid type", "invalidtype", false},
|
||||||
|
{"Invalid type varchar", "varchar", false},
|
||||||
|
{"Empty string", "", false},
|
||||||
|
{"Random string", "foobar", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ValidSQLType(tt.sqltype)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
anytype string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Go types to PostgreSQL types
|
||||||
|
{"Go bool to boolean", "bool", "boolean"},
|
||||||
|
{"Go int64 to bigint", "int64", "bigint"},
|
||||||
|
{"Go int to integer", "int", "integer"},
|
||||||
|
{"Go string to text", "string", "text"},
|
||||||
|
{"Go float64 to double precision", "float64", "double precision"},
|
||||||
|
{"Go float32 to real", "float32", "real"},
|
||||||
|
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||||
|
|
||||||
|
// SQL types remain SQL types
|
||||||
|
{"SQL bigint", "bigint", "bigint"},
|
||||||
|
{"SQL integer", "integer", "integer"},
|
||||||
|
{"SQL text", "text", "text"},
|
||||||
|
{"SQL boolean", "boolean", "boolean"},
|
||||||
|
{"SQL uuid", "uuid", "uuid"},
|
||||||
|
{"SQL jsonb", "jsonb", "jsonb"},
|
||||||
|
|
||||||
|
// Case insensitive Go types
|
||||||
|
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||||
|
{"Case insensitive InT64", "InT64", "bigint"},
|
||||||
|
{"Case insensitive STRING", "STRING", "text"},
|
||||||
|
|
||||||
|
// Case insensitive SQL types
|
||||||
|
{"Case insensitive BIGINT", "BIGINT", "bigint"},
|
||||||
|
{"Case insensitive TEXT", "TEXT", "text"},
|
||||||
|
|
||||||
|
// Custom types
|
||||||
|
{"Custom sqluuid", "sqluuid", "uuid"},
|
||||||
|
{"Custom sqljsonb", "sqljsonb", "jsonb"},
|
||||||
|
{"Custom sqlint64", "sqlint64", "bigint"},
|
||||||
|
|
||||||
|
// Unknown types default to text
|
||||||
|
{"Unknown type varchar", "varchar", "text"},
|
||||||
|
{"Unknown type foobar", "foobar", "text"},
|
||||||
|
{"Empty string", "", "text"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetSQLType(tt.anytype)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
anytype string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Go types to PostgreSQL types
|
||||||
|
{"Go bool to boolean", "bool", "boolean"},
|
||||||
|
{"Go int64 to bigint", "int64", "bigint"},
|
||||||
|
{"Go int to integer", "int", "integer"},
|
||||||
|
{"Go string to text", "string", "text"},
|
||||||
|
{"Go float64 to double precision", "float64", "double precision"},
|
||||||
|
{"Go float32 to real", "float32", "real"},
|
||||||
|
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||||
|
|
||||||
|
// SQL types remain SQL types
|
||||||
|
{"SQL bigint", "bigint", "bigint"},
|
||||||
|
{"SQL integer", "integer", "integer"},
|
||||||
|
{"SQL text", "text", "text"},
|
||||||
|
{"SQL boolean", "boolean", "boolean"},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||||
|
{"Case insensitive InT64", "InT64", "bigint"},
|
||||||
|
|
||||||
|
// Unknown types remain unchanged (difference from GetSQLType)
|
||||||
|
{"Unknown type varchar", "varchar", "varchar"},
|
||||||
|
{"Unknown type foobar", "foobar", "foobar"},
|
||||||
|
{"Empty string", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ConvertSQLType(tt.anytype)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsGoType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
typeName string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// Go basic types
|
||||||
|
{"Go bool", "bool", true},
|
||||||
|
{"Go int64", "int64", true},
|
||||||
|
{"Go int", "int", true},
|
||||||
|
{"Go int32", "int32", true},
|
||||||
|
{"Go int16", "int16", true},
|
||||||
|
{"Go int8", "int8", true},
|
||||||
|
{"Go uint", "uint", true},
|
||||||
|
{"Go uint64", "uint64", true},
|
||||||
|
{"Go uint32", "uint32", true},
|
||||||
|
{"Go uint16", "uint16", true},
|
||||||
|
{"Go uint8", "uint8", true},
|
||||||
|
{"Go float64", "float64", true},
|
||||||
|
{"Go float32", "float32", true},
|
||||||
|
{"Go string", "string", true},
|
||||||
|
{"Go []byte", "[]byte", true},
|
||||||
|
|
||||||
|
// Go custom types
|
||||||
|
{"Go complex64", "complex64", true},
|
||||||
|
{"Go complex128", "complex128", true},
|
||||||
|
{"Go uintptr", "uintptr", true},
|
||||||
|
{"Go Pointer", "Pointer", true},
|
||||||
|
|
||||||
|
// Custom SQL types
|
||||||
|
{"Custom sqluuid", "sqluuid", true},
|
||||||
|
{"Custom sqljsonb", "sqljsonb", true},
|
||||||
|
{"Custom sqlint64", "sqlint64", true},
|
||||||
|
{"Custom customdate", "customdate", true},
|
||||||
|
{"Custom customtime", "customtime", true},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BOOL", "BOOL", true},
|
||||||
|
{"Case insensitive InT64", "InT64", true},
|
||||||
|
{"Case insensitive STRING", "STRING", true},
|
||||||
|
|
||||||
|
// SQL types (not Go types)
|
||||||
|
{"SQL bigint", "bigint", false},
|
||||||
|
{"SQL integer", "integer", false},
|
||||||
|
{"SQL text", "text", false},
|
||||||
|
{"SQL boolean", "boolean", false},
|
||||||
|
|
||||||
|
// Invalid types
|
||||||
|
{"Invalid type", "invalidtype", false},
|
||||||
|
{"Empty string", "", false},
|
||||||
|
{"Random string", "foobar", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsGoType(tt.typeName)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStdTypeFromGo(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
typeName string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Go types to standard SQL types
|
||||||
|
{"Go bool to boolean", "bool", "boolean"},
|
||||||
|
{"Go int64 to bigint", "int64", "bigint"},
|
||||||
|
{"Go int to integer", "int", "integer"},
|
||||||
|
{"Go string to text", "string", "text"},
|
||||||
|
{"Go float64 to double", "float64", "double"},
|
||||||
|
{"Go float32 to double", "float32", "double"},
|
||||||
|
{"Go []byte to blob", "[]byte", "blob"},
|
||||||
|
{"Go int32 to integer", "int32", "integer"},
|
||||||
|
{"Go int16 to smallint", "int16", "smallint"},
|
||||||
|
|
||||||
|
// Custom types
|
||||||
|
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
|
||||||
|
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
|
||||||
|
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
|
||||||
|
{"Custom customdate to date", "customdate", "date"},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||||
|
{"Case insensitive InT64", "InT64", "bigint"},
|
||||||
|
{"Case insensitive STRING", "STRING", "text"},
|
||||||
|
|
||||||
|
// Non-Go types remain unchanged
|
||||||
|
{"SQL bigint unchanged", "bigint", "bigint"},
|
||||||
|
{"SQL integer unchanged", "integer", "integer"},
|
||||||
|
{"Invalid type unchanged", "invalidtype", "invalidtype"},
|
||||||
|
{"Empty string unchanged", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetStdTypeFromGo(tt.typeName)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoToStdTypesMap(t *testing.T) {
|
||||||
|
// Test that the map contains expected entries
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"bool": "boolean",
|
||||||
|
"int64": "bigint",
|
||||||
|
"int": "integer",
|
||||||
|
"string": "text",
|
||||||
|
"float64": "double",
|
||||||
|
"[]byte": "blob",
|
||||||
|
}
|
||||||
|
|
||||||
|
for goType, expectedStd := range expectedMappings {
|
||||||
|
if stdType, ok := GoToStdTypes[goType]; !ok {
|
||||||
|
t.Errorf("GoToStdTypes missing entry for %q", goType)
|
||||||
|
} else if stdType != expectedStd {
|
||||||
|
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the map is not empty
|
||||||
|
if len(GoToStdTypes) == 0 {
|
||||||
|
t.Error("GoToStdTypes map is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoToPGSQLTypesMap(t *testing.T) {
|
||||||
|
// Test that the map contains expected entries
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"bool": "boolean",
|
||||||
|
"int64": "bigint",
|
||||||
|
"int": "integer",
|
||||||
|
"string": "text",
|
||||||
|
"float64": "double precision",
|
||||||
|
"float32": "real",
|
||||||
|
"[]byte": "bytea",
|
||||||
|
}
|
||||||
|
|
||||||
|
for goType, expectedPG := range expectedMappings {
|
||||||
|
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
|
||||||
|
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
|
||||||
|
} else if pgType != expectedPG {
|
||||||
|
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the map is not empty
|
||||||
|
if len(GoToPGSQLTypes) == 0 {
|
||||||
|
t.Error("GoToPGSQLTypes map is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeConversionConsistency(t *testing.T) {
|
||||||
|
// Test that GetSQLType and ConvertSQLType are consistent for known types
|
||||||
|
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
|
||||||
|
|
||||||
|
for _, goType := range knownGoTypes {
|
||||||
|
getSQLResult := GetSQLType(goType)
|
||||||
|
convertResult := ConvertSQLType(goType)
|
||||||
|
|
||||||
|
if getSQLResult != convertResult {
|
||||||
|
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
|
||||||
|
goType, getSQLResult, convertResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
|
||||||
|
// Test that GetSQLType returns "text" for unknown types
|
||||||
|
// while ConvertSQLType returns the input unchanged
|
||||||
|
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
|
||||||
|
|
||||||
|
for _, unknown := range unknownTypes {
|
||||||
|
getSQLResult := GetSQLType(unknown)
|
||||||
|
convertResult := ConvertSQLType(unknown)
|
||||||
|
|
||||||
|
if getSQLResult != "text" {
|
||||||
|
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
|
||||||
|
}
|
||||||
|
|
||||||
|
if convertResult != unknown {
|
||||||
|
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
136
pkg/pgsql/keywords_test.go
Normal file
136
pkg/pgsql/keywords_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package pgsql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetPostgresKeywords(t *testing.T) {
|
||||||
|
keywords := GetPostgresKeywords()
|
||||||
|
|
||||||
|
// Test that keywords are returned
|
||||||
|
if len(keywords) == 0 {
|
||||||
|
t.Fatal("Expected non-empty list of keywords")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that we get all keywords from the map
|
||||||
|
expectedCount := len(postgresKeywords)
|
||||||
|
if len(keywords) != expectedCount {
|
||||||
|
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that all returned keywords exist in the map
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if !postgresKeywords[keyword] {
|
||||||
|
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that no duplicate keywords are returned
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if seen[keyword] {
|
||||||
|
t.Errorf("Duplicate keyword found: %q", keyword)
|
||||||
|
}
|
||||||
|
seen[keyword] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresKeywordsMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyword string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"SELECT keyword", "select", true},
|
||||||
|
{"FROM keyword", "from", true},
|
||||||
|
{"WHERE keyword", "where", true},
|
||||||
|
{"TABLE keyword", "table", true},
|
||||||
|
{"PRIMARY keyword", "primary", true},
|
||||||
|
{"FOREIGN keyword", "foreign", true},
|
||||||
|
{"CREATE keyword", "create", true},
|
||||||
|
{"DROP keyword", "drop", true},
|
||||||
|
{"ALTER keyword", "alter", true},
|
||||||
|
{"INDEX keyword", "index", true},
|
||||||
|
{"NOT keyword", "not", true},
|
||||||
|
{"NULL keyword", "null", true},
|
||||||
|
{"TRUE keyword", "true", true},
|
||||||
|
{"FALSE keyword", "false", true},
|
||||||
|
{"Non-keyword lowercase", "notakeyword", false},
|
||||||
|
{"Non-keyword uppercase", "NOTAKEYWORD", false},
|
||||||
|
{"Empty string", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := postgresKeywords[tt.keyword]
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresKeywordsMapContent(t *testing.T) {
|
||||||
|
// Test that the map contains expected common keywords
|
||||||
|
commonKeywords := []string{
|
||||||
|
"select", "insert", "update", "delete", "create", "drop", "alter",
|
||||||
|
"table", "index", "view", "schema", "function", "procedure",
|
||||||
|
"primary", "foreign", "key", "constraint", "unique", "check",
|
||||||
|
"null", "not", "and", "or", "like", "in", "between",
|
||||||
|
"join", "inner", "left", "right", "cross", "full", "outer",
|
||||||
|
"where", "having", "group", "order", "limit", "offset",
|
||||||
|
"union", "intersect", "except",
|
||||||
|
"begin", "commit", "rollback", "transaction",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range commonKeywords {
|
||||||
|
if !postgresKeywords[keyword] {
|
||||||
|
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresKeywordsMapSize(t *testing.T) {
|
||||||
|
// PostgreSQL has a substantial list of reserved keywords
|
||||||
|
// This test ensures the map has a reasonable number of entries
|
||||||
|
minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words
|
||||||
|
|
||||||
|
if len(postgresKeywords) < minExpectedKeywords {
|
||||||
|
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
|
||||||
|
minExpectedKeywords, len(postgresKeywords))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPostgresKeywordsConsistency(t *testing.T) {
|
||||||
|
// Test that calling GetPostgresKeywords multiple times returns consistent results
|
||||||
|
keywords1 := GetPostgresKeywords()
|
||||||
|
keywords2 := GetPostgresKeywords()
|
||||||
|
|
||||||
|
if len(keywords1) != len(keywords2) {
|
||||||
|
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
|
||||||
|
len(keywords1), len(keywords2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a map from both results to compare
|
||||||
|
map1 := make(map[string]bool)
|
||||||
|
map2 := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, k := range keywords1 {
|
||||||
|
map1[k] = true
|
||||||
|
}
|
||||||
|
for _, k := range keywords2 {
|
||||||
|
map2[k] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that both contain the same keywords
|
||||||
|
for k := range map1 {
|
||||||
|
if !map2[k] {
|
||||||
|
t.Errorf("Keyword %q present in first call but not in second", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range map2 {
|
||||||
|
if !map1[k] {
|
||||||
|
t.Errorf("Keyword %q present in second call but not in first", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -128,6 +128,46 @@ func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// splitIdentifier splits a dotted identifier while respecting quotes
|
||||||
|
// Handles cases like: "schema.with.dots"."table"."column"
|
||||||
|
func splitIdentifier(s string) []string {
|
||||||
|
var parts []string
|
||||||
|
var current strings.Builder
|
||||||
|
inQuote := false
|
||||||
|
quoteChar := byte(0)
|
||||||
|
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
ch := s[i]
|
||||||
|
|
||||||
|
if !inQuote {
|
||||||
|
switch ch {
|
||||||
|
case '"', '\'':
|
||||||
|
inQuote = true
|
||||||
|
quoteChar = ch
|
||||||
|
current.WriteByte(ch)
|
||||||
|
case '.':
|
||||||
|
if current.Len() > 0 {
|
||||||
|
parts = append(parts, current.String())
|
||||||
|
current.Reset()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
current.WriteByte(ch)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
current.WriteByte(ch)
|
||||||
|
if ch == quoteChar {
|
||||||
|
inQuote = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.Len() > 0 {
|
||||||
|
parts = append(parts, current.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
// stripQuotes removes surrounding quotes and comments from an identifier
|
// stripQuotes removes surrounding quotes and comments from an identifier
|
||||||
func stripQuotes(s string) string {
|
func stripQuotes(s string) string {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
@@ -409,7 +449,9 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
|||||||
// Parse Table definition
|
// Parse Table definition
|
||||||
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
|
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
|
||||||
tableName := matches[1]
|
tableName := matches[1]
|
||||||
parts := strings.Split(tableName, ".")
|
// Strip comments/notes before parsing to avoid dots in notes
|
||||||
|
tableName = strings.TrimSpace(regexp.MustCompile(`\s*\[.*?\]\s*`).ReplaceAllString(tableName, ""))
|
||||||
|
parts := splitIdentifier(tableName)
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
currentSchema = stripQuotes(parts[0])
|
currentSchema = stripQuotes(parts[0])
|
||||||
@@ -561,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
|||||||
column.Default = strings.Trim(defaultVal, "'\"")
|
column.Default = strings.Trim(defaultVal, "'\"")
|
||||||
} else if attr == "unique" {
|
} else if attr == "unique" {
|
||||||
// Create a unique constraint
|
// Create a unique constraint
|
||||||
|
// Clean table name by removing leading underscores to avoid double underscores
|
||||||
|
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||||
uniqueConstraint := models.InitConstraint(
|
uniqueConstraint := models.InitConstraint(
|
||||||
fmt.Sprintf("uq_%s", columnName),
|
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||||
models.UniqueConstraint,
|
models.UniqueConstraint,
|
||||||
)
|
)
|
||||||
uniqueConstraint.Schema = schemaName
|
uniqueConstraint.Schema = schemaName
|
||||||
@@ -587,10 +631,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
|||||||
refOp := strings.TrimSpace(refStr)
|
refOp := strings.TrimSpace(refStr)
|
||||||
var isReverse bool
|
var isReverse bool
|
||||||
if strings.HasPrefix(refOp, "<") {
|
if strings.HasPrefix(refOp, "<") {
|
||||||
isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse)
|
// < means "is referenced by" - only makes sense on PK columns
|
||||||
} else if strings.HasPrefix(refOp, ">") {
|
isReverse = column.IsPrimaryKey
|
||||||
isReverse = !column.IsPrimaryKey // > on FK means reverse
|
|
||||||
}
|
}
|
||||||
|
// > means "references" - always a forward FK, never reverse
|
||||||
|
|
||||||
constraint = r.parseRef(refStr)
|
constraint = r.parseRef(refStr)
|
||||||
if constraint != nil {
|
if constraint != nil {
|
||||||
@@ -610,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
|||||||
constraint.Table = tableName
|
constraint.Table = tableName
|
||||||
constraint.Columns = []string{columnName}
|
constraint.Columns = []string{columnName}
|
||||||
}
|
}
|
||||||
// Generate short constraint name based on the column
|
// Generate constraint name based on table and columns
|
||||||
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
|
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -695,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
|||||||
|
|
||||||
// Generate name if not provided
|
// Generate name if not provided
|
||||||
if index.Name == "" {
|
if index.Name == "" {
|
||||||
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
|
prefix := "idx"
|
||||||
|
if index.Unique {
|
||||||
|
prefix = "uidx"
|
||||||
|
}
|
||||||
|
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return index
|
return index
|
||||||
@@ -755,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate short constraint name based on the source column
|
// Generate constraint name based on table and columns
|
||||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
|
||||||
if len(fromColumns) > 0 {
|
if len(fromColumns) == 0 {
|
||||||
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
|
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
constraint := models.InitConstraint(
|
constraint := models.InitConstraint(
|
||||||
@@ -814,7 +862,7 @@ func (r *Reader) parseTableRef(ref string) (schema, table string, columns []stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse schema, table, and optionally column
|
// Parse schema, table, and optionally column
|
||||||
parts := strings.Split(strings.TrimSpace(ref), ".")
|
parts := splitIdentifier(strings.TrimSpace(ref))
|
||||||
if len(parts) == 3 {
|
if len(parts) == 3 {
|
||||||
// Format: "schema"."table"."column"
|
// Format: "schema"."table"."column"
|
||||||
schema = stripQuotes(parts[0])
|
schema = stripQuotes(parts[0])
|
||||||
|
|||||||
@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConstraintNaming(t *testing.T) {
|
||||||
|
// Test that constraints are named with proper prefixes
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find users table
|
||||||
|
var usersTable *models.Table
|
||||||
|
var postsTable *models.Table
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
if table.Name == "users" {
|
||||||
|
usersTable = table
|
||||||
|
} else if table.Name == "posts" {
|
||||||
|
postsTable = table
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usersTable == nil {
|
||||||
|
t.Fatal("Users table not found")
|
||||||
|
}
|
||||||
|
if postsTable == nil {
|
||||||
|
t.Fatal("Posts table not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unique constraint naming: ukey_table_column
|
||||||
|
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
|
||||||
|
t.Error("Expected unique constraint 'ukey_users_email' not found")
|
||||||
|
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
|
||||||
|
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
|
||||||
|
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test foreign key naming: fk_table_column
|
||||||
|
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
|
||||||
|
t.Error("Expected foreign key 'fk_posts_user_id' not found")
|
||||||
|
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unique index naming: uidx_table_columns
|
||||||
|
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
|
||||||
|
t.Error("Expected unique index 'uidx_posts_slug' not found")
|
||||||
|
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test regular index naming: idx_table_columns
|
||||||
|
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
|
||||||
|
t.Error("Expected index 'idx_posts_user_id_published' not found")
|
||||||
|
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeys[V any](m map[string]V) []string {
|
||||||
|
keys := make([]string, 0, len(m))
|
||||||
|
for k := range m {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasCommentedRefs(t *testing.T) {
|
func TestHasCommentedRefs(t *testing.T) {
|
||||||
// Test with the actual multifile test fixtures
|
// Test with the actual multifile test fixtures
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -329,10 +329,10 @@ func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint)
|
|||||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||||
|
|
||||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||||
relationship.FromTable = fk.ReferencedTable
|
relationship.FromTable = table.Name
|
||||||
relationship.FromSchema = fk.ReferencedSchema
|
relationship.FromSchema = table.Schema
|
||||||
relationship.ToTable = table.Name
|
relationship.ToTable = fk.ReferencedTable
|
||||||
relationship.ToSchema = table.Schema
|
relationship.ToSchema = fk.ReferencedSchema
|
||||||
relationship.ForeignKey = fk.Name
|
relationship.ForeignKey = fk.Name
|
||||||
|
|
||||||
// Store constraint actions in properties
|
// Store constraint actions in properties
|
||||||
|
|||||||
@@ -328,12 +328,12 @@ func TestDeriveRelationship(t *testing.T) {
|
|||||||
t.Errorf("Expected relationship type %s, got %s", models.OneToMany, rel.Type)
|
t.Errorf("Expected relationship type %s, got %s", models.OneToMany, rel.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel.FromTable != "users" {
|
if rel.FromTable != "orders" {
|
||||||
t.Errorf("Expected FromTable 'users', got '%s'", rel.FromTable)
|
t.Errorf("Expected FromTable 'orders', got '%s'", rel.FromTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel.ToTable != "orders" {
|
if rel.ToTable != "users" {
|
||||||
t.Errorf("Expected ToTable 'orders', got '%s'", rel.ToTable)
|
t.Errorf("Expected ToTable 'users', got '%s'", rel.ToTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel.ForeignKey != "fk_orders_user_id" {
|
if rel.ForeignKey != "fk_orders_user_id" {
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ fmt.Printf("Found %d scripts\n", len(schema.Scripts))
|
|||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Recursive Directory Scanning**: Automatically scans all subdirectories
|
- **Recursive Directory Scanning**: Automatically scans all subdirectories
|
||||||
|
- **Symlink Skipping**: Symbolic links are automatically skipped (prevents loops and duplicates)
|
||||||
- **Multiple Extensions**: Supports both `.sql` and `.pgsql` files
|
- **Multiple Extensions**: Supports both `.sql` and `.pgsql` files
|
||||||
- **Flexible Naming**: Extract metadata from filename patterns
|
- **Flexible Naming**: Extract metadata from filename patterns
|
||||||
- **Error Handling**: Validates directory existence and file accessibility
|
- **Error Handling**: Validates directory existence and file accessibility
|
||||||
@@ -153,8 +154,9 @@ go test ./pkg/readers/sqldir/
|
|||||||
```
|
```
|
||||||
|
|
||||||
Tests include:
|
Tests include:
|
||||||
- Valid file parsing
|
- Valid file parsing (underscore and hyphen formats)
|
||||||
- Recursive directory scanning
|
- Recursive directory scanning
|
||||||
|
- Symlink skipping
|
||||||
- Invalid filename handling
|
- Invalid filename handling
|
||||||
- Empty directory handling
|
- Empty directory handling
|
||||||
- Error conditions
|
- Error conditions
|
||||||
|
|||||||
@@ -107,11 +107,20 @@ func (r *Reader) readScripts() ([]*models.Script, error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip directories
|
// Don't process directories as files (WalkDir still descends into them recursively)
|
||||||
if d.IsDir() {
|
if d.IsDir() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip symlinks
|
||||||
|
info, err := d.Info()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info.Mode()&os.ModeSymlink != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get filename
|
// Get filename
|
||||||
filename := d.Name()
|
filename := d.Name()
|
||||||
|
|
||||||
|
|||||||
@@ -373,3 +373,65 @@ func TestReader_MixedFormat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReader_SkipSymlinks(t *testing.T) {
|
||||||
|
// Create temporary test directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "sqldir-test-symlink-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create a real SQL file
|
||||||
|
realFile := filepath.Join(tempDir, "1_001_real_file.sql")
|
||||||
|
if err := os.WriteFile(realFile, []byte("SELECT 1;"), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to create real file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create another file to link to
|
||||||
|
targetFile := filepath.Join(tempDir, "2_001_target.sql")
|
||||||
|
if err := os.WriteFile(targetFile, []byte("SELECT 2;"), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to create target file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a symlink to the target file (this should be skipped)
|
||||||
|
symlinkFile := filepath.Join(tempDir, "3_001_symlink.sql")
|
||||||
|
if err := os.Symlink(targetFile, symlinkFile); err != nil {
|
||||||
|
// Skip test on systems that don't support symlinks (e.g., Windows without admin)
|
||||||
|
t.Skipf("Symlink creation not supported: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create reader
|
||||||
|
reader := NewReader(&readers.ReaderOptions{
|
||||||
|
FilePath: tempDir,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Read database
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := db.Schemas[0]
|
||||||
|
|
||||||
|
// Should only have 2 scripts (real_file and target), symlink should be skipped
|
||||||
|
if len(schema.Scripts) != 2 {
|
||||||
|
t.Errorf("Expected 2 scripts (symlink should be skipped), got %d", len(schema.Scripts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the scripts are the real files, not the symlink
|
||||||
|
scriptNames := make(map[string]bool)
|
||||||
|
for _, script := range schema.Scripts {
|
||||||
|
scriptNames[script.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scriptNames["real_file"] {
|
||||||
|
t.Error("Expected 'real_file' script to be present")
|
||||||
|
}
|
||||||
|
if !scriptNames["target"] {
|
||||||
|
t.Error("Expected 'target' script to be present")
|
||||||
|
}
|
||||||
|
if scriptNames["symlink"] {
|
||||||
|
t.Error("Symlink script should have been skipped but was found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
490
pkg/reflectutil/helpers_test.go
Normal file
490
pkg/reflectutil/helpers_test.go
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
package reflectutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testStruct struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
Active bool
|
||||||
|
Nested *nestedStruct
|
||||||
|
Private string
|
||||||
|
}
|
||||||
|
|
||||||
|
type nestedStruct struct {
|
||||||
|
Value string
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeref(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantValid bool
|
||||||
|
wantKind reflect.Kind
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-pointer int",
|
||||||
|
input: 42,
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Int,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single pointer",
|
||||||
|
input: ptrInt(42),
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Int,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "double pointer",
|
||||||
|
input: ptrPtr(ptrInt(42)),
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Int,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil pointer",
|
||||||
|
input: (*int)(nil),
|
||||||
|
wantValid: false,
|
||||||
|
wantKind: reflect.Ptr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
input: "test",
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.String,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "struct",
|
||||||
|
input: testStruct{Name: "test"},
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Struct,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := reflect.ValueOf(tt.input)
|
||||||
|
got, valid := Deref(v)
|
||||||
|
|
||||||
|
if valid != tt.wantValid {
|
||||||
|
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Kind() != tt.wantKind {
|
||||||
|
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDerefInterface(t *testing.T) {
|
||||||
|
i := 42
|
||||||
|
pi := &i
|
||||||
|
ppi := &pi
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantKind reflect.Kind
|
||||||
|
}{
|
||||||
|
{"int", 42, reflect.Int},
|
||||||
|
{"pointer to int", &i, reflect.Int},
|
||||||
|
{"double pointer to int", ppi, reflect.Int},
|
||||||
|
{"string", "test", reflect.String},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := DerefInterface(tt.input)
|
||||||
|
if got.Kind() != tt.wantKind {
|
||||||
|
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFieldValue(t *testing.T) {
|
||||||
|
ts := testStruct{
|
||||||
|
Name: "John",
|
||||||
|
Age: 30,
|
||||||
|
Active: true,
|
||||||
|
Nested: &nestedStruct{Value: "nested", Count: 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
item interface{}
|
||||||
|
field string
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"struct field Name", ts, "Name", "John"},
|
||||||
|
{"struct field Age", ts, "Age", 30},
|
||||||
|
{"struct field Active", ts, "Active", true},
|
||||||
|
{"struct non-existent field", ts, "NonExistent", nil},
|
||||||
|
{"pointer to struct", &ts, "Name", "John"},
|
||||||
|
{"map string key", map[string]string{"key": "value"}, "key", "value"},
|
||||||
|
{"map int key", map[string]int{"count": 42}, "count", 42},
|
||||||
|
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
|
||||||
|
{"nil pointer", (*testStruct)(nil), "Name", nil},
|
||||||
|
{"non-struct non-map", 42, "field", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetFieldValue(tt.item, tt.field)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSliceOrArray(t *testing.T) {
|
||||||
|
arr := [3]int{1, 2, 3}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"slice", []int{1, 2, 3}, true},
|
||||||
|
{"array", arr, true},
|
||||||
|
{"pointer to slice", &[]int{1, 2, 3}, true},
|
||||||
|
{"string", "test", false},
|
||||||
|
{"int", 42, false},
|
||||||
|
{"map", map[string]int{}, false},
|
||||||
|
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
|
||||||
|
{"nil pointer", (*[]int)(nil), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsSliceOrArray(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"map[string]int", map[string]int{"a": 1}, true},
|
||||||
|
{"map[int]string", map[int]string{1: "a"}, true},
|
||||||
|
{"pointer to map", &map[string]int{"a": 1}, true},
|
||||||
|
{"slice", []int{1, 2, 3}, false},
|
||||||
|
{"string", "test", false},
|
||||||
|
{"int", 42, false},
|
||||||
|
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
|
||||||
|
{"nil pointer", (*map[string]int)(nil), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsMap(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsMap() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceLen(t *testing.T) {
|
||||||
|
arr := [3]int{1, 2, 3}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"slice length 3", []int{1, 2, 3}, 3},
|
||||||
|
{"empty slice", []int{}, 0},
|
||||||
|
{"array length 3", arr, 3},
|
||||||
|
{"pointer to slice", &[]int{1, 2, 3}, 3},
|
||||||
|
{"not a slice", "test", 0},
|
||||||
|
{"int", 42, 0},
|
||||||
|
{"nil slice", ([]int)(nil), 0},
|
||||||
|
{"nil pointer", (*[]int)(nil), 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SliceLen(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapLen(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
|
||||||
|
{"empty map", map[string]int{}, 0},
|
||||||
|
{"pointer to map", &map[string]int{"a": 1}, 1},
|
||||||
|
{"not a map", []int{1, 2, 3}, 0},
|
||||||
|
{"string", "test", 0},
|
||||||
|
{"nil map", (map[string]int)(nil), 0},
|
||||||
|
{"nil pointer", (*map[string]int)(nil), 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapLen(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("MapLen() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceToInterfaces(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want []interface{}
|
||||||
|
}{
|
||||||
|
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
|
||||||
|
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
|
||||||
|
{"empty slice", []int{}, []interface{}{}},
|
||||||
|
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
|
||||||
|
{"not a slice", "test", []interface{}{}},
|
||||||
|
{"nil slice", ([]int)(nil), []interface{}{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SliceToInterfaces(tt.input)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapKeys(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want []interface{}
|
||||||
|
}{
|
||||||
|
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
|
||||||
|
{"empty map", map[string]int{}, []interface{}{}},
|
||||||
|
{"not a map", []int{1, 2, 3}, []interface{}{}},
|
||||||
|
{"nil map", (map[string]int)(nil), []interface{}{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapKeys(tt.input)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
|
||||||
|
}
|
||||||
|
// For maps, order is not guaranteed, so just check length
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapValues(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want int // length of values
|
||||||
|
}{
|
||||||
|
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
|
||||||
|
{"empty map", map[string]int{}, 0},
|
||||||
|
{"not a map", []int{1, 2, 3}, 0},
|
||||||
|
{"nil map", (map[string]int)(nil), 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapValues(tt.input)
|
||||||
|
if len(got) != tt.want {
|
||||||
|
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapGet(t *testing.T) {
|
||||||
|
m := map[string]int{"a": 1, "b": 2}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
key interface{}
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"existing key", m, "a", 1},
|
||||||
|
{"existing key b", m, "b", 2},
|
||||||
|
{"non-existing key", m, "c", nil},
|
||||||
|
{"pointer to map", &m, "a", 1},
|
||||||
|
{"not a map", []int{1, 2}, 0, nil},
|
||||||
|
{"nil map", (map[string]int)(nil), "a", nil},
|
||||||
|
{"nil pointer", (*map[string]int)(nil), "a", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapGet(tt.input, tt.key)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("MapGet() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceIndex(t *testing.T) {
|
||||||
|
s := []int{10, 20, 30}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
slice interface{}
|
||||||
|
index int
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"index 0", s, 0, 10},
|
||||||
|
{"index 1", s, 1, 20},
|
||||||
|
{"index 2", s, 2, 30},
|
||||||
|
{"negative index", s, -1, nil},
|
||||||
|
{"out of bounds", s, 5, nil},
|
||||||
|
{"pointer to slice", &s, 1, 20},
|
||||||
|
{"not a slice", "test", 0, nil},
|
||||||
|
{"nil slice", ([]int)(nil), 0, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SliceIndex(tt.slice, tt.index)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareValues(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a interface{}
|
||||||
|
b interface{}
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"both nil", nil, nil, 0},
|
||||||
|
{"a nil", nil, 5, -1},
|
||||||
|
{"b nil", 5, nil, 1},
|
||||||
|
{"equal strings", "abc", "abc", 0},
|
||||||
|
{"a less than b strings", "abc", "xyz", -1},
|
||||||
|
{"a greater than b strings", "xyz", "abc", 1},
|
||||||
|
{"equal ints", 5, 5, 0},
|
||||||
|
{"a less than b ints", 3, 7, -1},
|
||||||
|
{"a greater than b ints", 10, 5, 1},
|
||||||
|
{"equal floats", 3.14, 3.14, 0},
|
||||||
|
{"a less than b floats", 2.5, 5.5, -1},
|
||||||
|
{"a greater than b floats", 10.5, 5.5, 1},
|
||||||
|
{"equal uints", uint(5), uint(5), 0},
|
||||||
|
{"different types", "abc", 123, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := CompareValues(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNestedValue(t *testing.T) {
|
||||||
|
nested := map[string]interface{}{
|
||||||
|
"level1": map[string]interface{}{
|
||||||
|
"level2": map[string]interface{}{
|
||||||
|
"value": "deep",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := testStruct{
|
||||||
|
Name: "John",
|
||||||
|
Nested: &nestedStruct{
|
||||||
|
Value: "nested value",
|
||||||
|
Count: 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
path string
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"empty path", nested, "", nested},
|
||||||
|
{"single level map", nested, "level1", nested["level1"]},
|
||||||
|
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
|
||||||
|
{"deep nested map", nested, "level1.level2.value", "deep"},
|
||||||
|
{"struct field", ts, "Name", "John"},
|
||||||
|
{"nested struct field", ts, "Nested", ts.Nested},
|
||||||
|
{"non-existent path", nested, "missing.path", nil},
|
||||||
|
{"nil input", nil, "path", nil},
|
||||||
|
{"partial missing path", nested, "level1.missing", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetNestedValue(tt.input, tt.path)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepEqual(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a interface{}
|
||||||
|
b interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"equal ints", 42, 42, true},
|
||||||
|
{"different ints", 42, 43, false},
|
||||||
|
{"equal strings", "test", "test", true},
|
||||||
|
{"different strings", "test", "other", false},
|
||||||
|
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
|
||||||
|
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
|
||||||
|
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
|
||||||
|
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
|
||||||
|
{"both nil", nil, nil, true},
|
||||||
|
{"one nil", nil, 42, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := DeepEqual(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func ptrInt(i int) *int {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrPtr(p *int) **int {
|
||||||
|
return &p
|
||||||
|
}
|
||||||
@@ -106,19 +106,20 @@ func (td *TemplateData) FinalizeImports() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewModelData creates a new ModelData from a models.Table
|
// NewModelData creates a new ModelData from a models.Table
|
||||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
|
||||||
tableName := table.Name
|
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
|
||||||
if schema != "" {
|
|
||||||
tableName = schema + "." + table.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate model name: singularize and convert to PascalCase
|
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||||
singularTable := Singularize(table.Name)
|
singularTable := Singularize(table.Name)
|
||||||
modelName := SnakeCaseToPascalCase(singularTable)
|
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||||
|
|
||||||
// Add "Model" prefix if not already present
|
// Include schema name in model name
|
||||||
if !hasModelPrefix(modelName) {
|
var modelName string
|
||||||
modelName = "Model" + modelName
|
if schema != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schema)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &ModelData{
|
model := &ModelData{
|
||||||
@@ -149,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,9 +193,28 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasModelPrefix checks if a name already has "Model" prefix
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
func hasModelPrefix(name string) bool {
|
// and adds an underscore suffix if there's a collision
|
||||||
return len(name) >= 5 && name[:5] == "Model"
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -85,7 +86,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
// Collect all models
|
// Collect all models
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||||
|
|
||||||
// Add relationship fields
|
// Add relationship fields
|
||||||
w.addRelationshipFields(modelData, table, schema, db)
|
w.addRelationshipFields(modelData, table, schema, db)
|
||||||
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
return w.writeOutput(formatted)
|
if err := w.writeOutput(formatted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the output file
|
||||||
|
if w.options.OutputPath != "" {
|
||||||
|
w.runGoFmt(w.options.OutputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMultiFile writes each table to a separate file
|
// writeMultiFile writes each table to a separate file
|
||||||
@@ -171,7 +181,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||||
|
|
||||||
// Create model data
|
// Create model data
|
||||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||||
|
|
||||||
// Add relationship fields
|
// Add relationship fields
|
||||||
w.addRelationshipFields(modelData, table, schema, db)
|
w.addRelationshipFields(modelData, table, schema, db)
|
||||||
@@ -217,6 +227,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the generated file
|
||||||
|
w.runGoFmt(filepath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,7 +254,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||||
fieldName := w.generateHasOneFieldName(constraint)
|
fieldName := w.generateHasOneFieldName(constraint)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
||||||
@@ -270,8 +283,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
// Check if this constraint references our table
|
// Check if this constraint references our table
|
||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||||
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
||||||
|
|
||||||
@@ -303,13 +316,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModelName generates the model name from a table name
|
// getModelName generates the model name from schema and table name
|
||||||
func (w *Writer) getModelName(tableName string) string {
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||||
singular := Singularize(tableName)
|
singular := Singularize(tableName)
|
||||||
modelName := SnakeCaseToPascalCase(singular)
|
tablePart := SnakeCaseToPascalCase(singular)
|
||||||
|
|
||||||
if !hasModelPrefix(modelName) {
|
// Include schema name in model name
|
||||||
modelName = "Model" + modelName
|
var modelName string
|
||||||
|
if schemaName != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelName
|
return modelName
|
||||||
@@ -333,13 +351,13 @@ func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
|||||||
|
|
||||||
// generateHasManyFieldName generates a field name for has-many relationships
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
// Uses the foreign key column name + source table name to avoid duplicates
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||||
// For has-many, we need to include the source table name to avoid duplicates
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
// e.g., multiple tables referencing the same column on this table
|
// e.g., multiple tables referencing the same column on this table
|
||||||
if len(constraint.Columns) > 0 {
|
if len(constraint.Columns) > 0 {
|
||||||
columnName := constraint.Columns[0]
|
columnName := constraint.Columns[0]
|
||||||
// Get the model name for the source table (pluralized)
|
// Get the model name for the source table (pluralized)
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
// Remove "Model" prefix if present
|
// Remove "Model" prefix if present
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
@@ -350,7 +368,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to table-based naming
|
// Fallback to table-based naming
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
return "Rel" + Pluralize(sourceModelName)
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
}
|
}
|
||||||
@@ -399,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runGoFmt runs go fmt on the specified file
|
||||||
|
func (w *Writer) runGoFmt(filepath string) {
|
||||||
|
cmd := exec.Command("gofmt", "-w", filepath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
// Don't fail the whole operation if gofmt fails, just warn
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||||
func (w *Writer) shouldUseMultiFile() bool {
|
func (w *Writer) shouldUseMultiFile() bool {
|
||||||
// Check if multi_file is explicitly set in metadata
|
// Check if multi_file is explicitly set in metadata
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
// Verify key elements are present
|
// Verify key elements are present
|
||||||
expectations := []string{
|
expectations := []string{
|
||||||
"package models",
|
"package models",
|
||||||
"type ModelUser struct",
|
"type ModelPublicUser struct",
|
||||||
"bun.BaseModel",
|
"bun.BaseModel",
|
||||||
"table:public.users",
|
"table:public.users",
|
||||||
"alias:users",
|
"alias:users",
|
||||||
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
"resolvespec_common.SqlTime",
|
"resolvespec_common.SqlTime",
|
||||||
"bun:\"id",
|
"bun:\"id",
|
||||||
"bun:\"email",
|
"bun:\"email",
|
||||||
"func (m ModelUser) TableName() string",
|
"func (m ModelPublicUser) TableName() string",
|
||||||
"return \"public.users\"",
|
"return \"public.users\"",
|
||||||
"func (m ModelUser) GetID() int64",
|
"func (m ModelPublicUser) GetID() int64",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range expectations {
|
for _, expected := range expectations {
|
||||||
@@ -191,9 +191,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
|
|
||||||
usersStr := string(usersContent)
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
// Should have RelUserIDPosts (has-many) field
|
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||||
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||||
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,8 +309,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
|||||||
|
|
||||||
// Should have two different has-many relationships with unique names
|
// Should have two different has-many relationships with unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||||
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -455,10 +455,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
|
|
||||||
// Verify all has-many relationships have unique names
|
// Verify all has-many relationships have unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDAPIProviderLogins", // Has many via Login
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
"RelRIDOwner", // Has one via rid_owner
|
"RelRIDOwner", // Has one via rid_owner
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -481,6 +481,74 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `bun:"table_name,`) {
|
||||||
|
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||||
mapper := NewTypeMapper()
|
mapper := NewTypeMapper()
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type ModelData struct {
|
|||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
IDColumnName string // Name of the ID column in database
|
IDColumnName string // Name of the ID column in database
|
||||||
Prefix string // 3-letter prefix
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
@@ -104,19 +105,20 @@ func (td *TemplateData) FinalizeImports() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewModelData creates a new ModelData from a models.Table
|
// NewModelData creates a new ModelData from a models.Table
|
||||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
|
||||||
tableName := table.Name
|
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
|
||||||
if schema != "" {
|
|
||||||
tableName = schema + "." + table.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate model name: singularize and convert to PascalCase
|
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||||
singularTable := Singularize(table.Name)
|
singularTable := Singularize(table.Name)
|
||||||
modelName := SnakeCaseToPascalCase(singularTable)
|
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||||
|
|
||||||
// Add "Model" prefix if not already present
|
// Include schema name in model name
|
||||||
if !hasModelPrefix(modelName) {
|
var modelName string
|
||||||
modelName = "Model" + modelName
|
if schema != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schema)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &ModelData{
|
model := &ModelData{
|
||||||
@@ -135,6 +137,7 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Sanitize column name to remove backticks
|
// Sanitize column name to remove backticks
|
||||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
model.IDColumnName = safeName
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -144,6 +147,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,9 +190,28 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasModelPrefix checks if a name already has "Model" prefix
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
func hasModelPrefix(name string) bool {
|
// and adds an underscore suffix if there's a collision
|
||||||
return len(name) >= 5 && name[:5] == "Model"
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
|
|||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||||
m.{{.PrimaryKeyField}} = int32(newid)
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -82,7 +83,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
// Collect all models
|
// Collect all models
|
||||||
for _, schema := range db.Schemas {
|
for _, schema := range db.Schemas {
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||||
|
|
||||||
// Add relationship fields
|
// Add relationship fields
|
||||||
w.addRelationshipFields(modelData, table, schema, db)
|
w.addRelationshipFields(modelData, table, schema, db)
|
||||||
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
return w.writeOutput(formatted)
|
if err := w.writeOutput(formatted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the output file
|
||||||
|
if w.options.OutputPath != "" {
|
||||||
|
w.runGoFmt(w.options.OutputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMultiFile writes each table to a separate file
|
// writeMultiFile writes each table to a separate file
|
||||||
@@ -165,7 +175,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||||
|
|
||||||
// Create model data
|
// Create model data
|
||||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||||
|
|
||||||
// Add relationship fields
|
// Add relationship fields
|
||||||
w.addRelationshipFields(modelData, table, schema, db)
|
w.addRelationshipFields(modelData, table, schema, db)
|
||||||
@@ -211,6 +221,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the generated file
|
||||||
|
w.runGoFmt(filepath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +248,7 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create relationship field (belongs-to)
|
// Create relationship field (belongs-to)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||||
fieldName := w.generateBelongsToFieldName(constraint)
|
fieldName := w.generateBelongsToFieldName(constraint)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||||
@@ -264,8 +277,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
// Check if this constraint references our table
|
// Check if this constraint references our table
|
||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||||
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||||
|
|
||||||
@@ -297,13 +310,18 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModelName generates the model name from a table name
|
// getModelName generates the model name from schema and table name
|
||||||
func (w *Writer) getModelName(tableName string) string {
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||||
singular := Singularize(tableName)
|
singular := Singularize(tableName)
|
||||||
modelName := SnakeCaseToPascalCase(singular)
|
tablePart := SnakeCaseToPascalCase(singular)
|
||||||
|
|
||||||
if !hasModelPrefix(modelName) {
|
// Include schema name in model name
|
||||||
modelName = "Model" + modelName
|
var modelName string
|
||||||
|
if schemaName != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelName
|
return modelName
|
||||||
@@ -327,13 +345,13 @@ func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) strin
|
|||||||
|
|
||||||
// generateHasManyFieldName generates a field name for has-many relationships
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
// Uses the foreign key column name + source table name to avoid duplicates
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||||
// For has-many, we need to include the source table name to avoid duplicates
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
// e.g., multiple tables referencing the same column on this table
|
// e.g., multiple tables referencing the same column on this table
|
||||||
if len(constraint.Columns) > 0 {
|
if len(constraint.Columns) > 0 {
|
||||||
columnName := constraint.Columns[0]
|
columnName := constraint.Columns[0]
|
||||||
// Get the model name for the source table (pluralized)
|
// Get the model name for the source table (pluralized)
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
// Remove "Model" prefix if present
|
// Remove "Model" prefix if present
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
@@ -344,7 +362,7 @@ func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to table-based naming
|
// Fallback to table-based naming
|
||||||
sourceModelName := w.getModelName(sourceTableName)
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
return "Rel" + Pluralize(sourceModelName)
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
}
|
}
|
||||||
@@ -393,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runGoFmt runs go fmt on the specified file
|
||||||
|
func (w *Writer) runGoFmt(filepath string) {
|
||||||
|
cmd := exec.Command("gofmt", "-w", filepath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
// Don't fail the whole operation if gofmt fails, just warn
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||||
func (w *Writer) shouldUseMultiFile() bool {
|
func (w *Writer) shouldUseMultiFile() bool {
|
||||||
// Check if multi_file is explicitly set in metadata
|
// Check if multi_file is explicitly set in metadata
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
// Verify key elements are present
|
// Verify key elements are present
|
||||||
expectations := []string{
|
expectations := []string{
|
||||||
"package models",
|
"package models",
|
||||||
"type ModelUser struct",
|
"type ModelPublicUser struct",
|
||||||
"ID",
|
"ID",
|
||||||
"int64",
|
"int64",
|
||||||
"Email",
|
"Email",
|
||||||
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
"time.Time",
|
"time.Time",
|
||||||
"gorm:\"column:id",
|
"gorm:\"column:id",
|
||||||
"gorm:\"column:email",
|
"gorm:\"column:email",
|
||||||
"func (m ModelUser) TableName() string",
|
"func (m ModelPublicUser) TableName() string",
|
||||||
"return \"public.users\"",
|
"return \"public.users\"",
|
||||||
"func (m ModelUser) GetID() int64",
|
"func (m ModelPublicUser) GetID() int64",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range expectations {
|
for _, expected := range expectations {
|
||||||
@@ -180,9 +180,9 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
|
|
||||||
usersStr := string(usersContent)
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
// Should have RelUserIDPosts (has-many) field
|
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||||
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||||
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,8 +298,8 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
|||||||
|
|
||||||
// Should have two different has-many relationships with unique names
|
// Should have two different has-many relationships with unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||||
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -444,10 +444,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
|
|
||||||
// Verify all has-many relationships have unique names
|
// Verify all has-many relationships have unique names
|
||||||
hasManyExpectations := []string{
|
hasManyExpectations := []string{
|
||||||
"RelRIDAPIProviderLogins", // Has many via Login
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
"RelRIDOwner", // Belongs to via rid_owner
|
"RelRIDOwner", // Belongs to via rid_owner
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, exp := range hasManyExpectations {
|
for _, exp := range hasManyExpectations {
|
||||||
@@ -470,6 +470,134 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `gorm:"column:table_name;`) {
|
||||||
|
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||||
|
// Test scenario: tables with different primary key types
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkType string
|
||||||
|
expectedPK string
|
||||||
|
castType string
|
||||||
|
}{
|
||||||
|
{"int32_pk", "int", "int32", "int32(newid)"},
|
||||||
|
{"int16_pk", "smallint", "int16", "int16(newid)"},
|
||||||
|
{"int64_pk", "bigint", "int64", "int64(newid)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table := models.InitTable("test_table", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: tt.pkType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify UpdateID method has correct type cast
|
||||||
|
if !strings.Contains(generated, tt.castType) {
|
||||||
|
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no invalid int32(newid) for non-int32 types
|
||||||
|
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
|
||||||
|
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify UpdateID parameter is int64 (for consistency)
|
||||||
|
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
217
pkg/writers/pgsql/NAMING_CONVENTIONS.md
Normal file
217
pkg/writers/pgsql/NAMING_CONVENTIONS.md
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
# PostgreSQL Naming Conventions
|
||||||
|
|
||||||
|
Standardized naming rules for all database objects in RelSpec PostgreSQL output.
|
||||||
|
|
||||||
|
## Quick Reference
|
||||||
|
|
||||||
|
| Object Type | Prefix | Format | Example |
|
||||||
|
| ----------------- | ----------- | ---------------------------------- | ------------------------ |
|
||||||
|
| Primary Key | `pk_` | `pk_<schema>_<table>` | `pk_public_users` |
|
||||||
|
| Foreign Key | `fk_` | `fk_<table>_<referenced_table>` | `fk_posts_users` |
|
||||||
|
| Unique Constraint | `uk_` | `uk_<table>_<column>` | `uk_users_email` |
|
||||||
|
| Unique Index | `uidx_` | `uidx_<table>_<column>` | `uidx_users_email` |
|
||||||
|
| Regular Index | `idx_` | `idx_<table>_<column>` | `idx_posts_user_id` |
|
||||||
|
| Check Constraint | `chk_` | `chk_<table>_<constraint_purpose>` | `chk_users_age_positive` |
|
||||||
|
| Sequence | `identity_` | `identity_<table>_<column>` | `identity_users_id` |
|
||||||
|
| Trigger | `t_` | `t_<purpose>_<table>` | `t_audit_users` |
|
||||||
|
| Trigger Function | `tf_` | `tf_<purpose>_<table>` | `tf_audit_users` |
|
||||||
|
|
||||||
|
## Naming Rules by Object Type
|
||||||
|
|
||||||
|
### Primary Keys
|
||||||
|
|
||||||
|
**Pattern:** `pk_<schema>_<table>`
|
||||||
|
|
||||||
|
- Include schema name to avoid collisions across schemas
|
||||||
|
- Use lowercase, snake_case format
|
||||||
|
- Examples:
|
||||||
|
- `pk_public_users`
|
||||||
|
- `pk_audit_audit_log`
|
||||||
|
- `pk_staging_temp_data`
|
||||||
|
|
||||||
|
### Foreign Keys
|
||||||
|
|
||||||
|
**Pattern:** `fk_<table>_<referenced_table>`
|
||||||
|
|
||||||
|
- Reference the table containing the FK followed by the referenced table
|
||||||
|
- Use lowercase, snake_case format
|
||||||
|
- Do NOT include column names in standard FK constraints
|
||||||
|
- Examples:
|
||||||
|
- `fk_posts_users` (posts.user_id → users.id)
|
||||||
|
- `fk_comments_posts` (comments.post_id → posts.id)
|
||||||
|
- `fk_order_items_orders` (order_items.order_id → orders.id)
|
||||||
|
|
||||||
|
### Unique Constraints
|
||||||
|
|
||||||
|
**Pattern:** `uk_<table>_<column>`
|
||||||
|
|
||||||
|
- Use `uk_` prefix strictly for database constraints (CONSTRAINT type)
|
||||||
|
- Include column name for clarity
|
||||||
|
- Examples:
|
||||||
|
- `uk_users_email`
|
||||||
|
- `uk_users_username`
|
||||||
|
- `uk_products_sku`
|
||||||
|
|
||||||
|
### Unique Indexes
|
||||||
|
|
||||||
|
**Pattern:** `uidx_<table>_<column>`
|
||||||
|
|
||||||
|
- Use `uidx_` prefix strictly for index type objects
|
||||||
|
- Distinguished from constraints for clarity and implementation flexibility
|
||||||
|
- Examples:
|
||||||
|
- `uidx_users_email`
|
||||||
|
- `uidx_sessions_token`
|
||||||
|
- `uidx_api_keys_key`
|
||||||
|
|
||||||
|
### Regular Indexes
|
||||||
|
|
||||||
|
**Pattern:** `idx_<table>_<column>`
|
||||||
|
|
||||||
|
- Standard indexes for query optimization
|
||||||
|
- Single column: `idx_<table>_<column>`
|
||||||
|
- Examples:
|
||||||
|
- `idx_posts_user_id`
|
||||||
|
- `idx_orders_created_at`
|
||||||
|
- `idx_users_status`
|
||||||
|
|
||||||
|
### Check Constraints
|
||||||
|
|
||||||
|
**Pattern:** `chk_<table>_<constraint_purpose>`
|
||||||
|
|
||||||
|
- Describe the constraint validation purpose
|
||||||
|
- Use lowercase, snake_case for the purpose
|
||||||
|
- Examples:
|
||||||
|
- `chk_users_age_positive` (CHECK (age > 0))
|
||||||
|
- `chk_orders_quantity_positive` (CHECK (quantity > 0))
|
||||||
|
- `chk_products_price_valid` (CHECK (price >= 0))
|
||||||
|
- `chk_users_status_enum` (CHECK (status IN ('active', 'inactive')))
|
||||||
|
|
||||||
|
### Sequences
|
||||||
|
|
||||||
|
**Pattern:** `identity_<table>_<column>`
|
||||||
|
|
||||||
|
- Used for SERIAL/IDENTITY columns
|
||||||
|
- Explicitly named for clarity and management
|
||||||
|
- Examples:
|
||||||
|
- `identity_users_id`
|
||||||
|
- `identity_posts_id`
|
||||||
|
- `identity_transactions_id`
|
||||||
|
|
||||||
|
### Triggers
|
||||||
|
|
||||||
|
**Pattern:** `t_<purpose>_<table>`
|
||||||
|
|
||||||
|
- Include purpose before table name
|
||||||
|
- Lowercase, snake_case format
|
||||||
|
- Examples:
|
||||||
|
- `t_audit_users` (audit trigger on users table)
|
||||||
|
- `t_update_timestamp_posts` (timestamp update trigger on posts)
|
||||||
|
- `t_validate_orders` (validation trigger on orders)
|
||||||
|
|
||||||
|
### Trigger Functions
|
||||||
|
|
||||||
|
**Pattern:** `tf_<purpose>_<table>`
|
||||||
|
|
||||||
|
- Pair with trigger naming convention
|
||||||
|
- Use `tf_` prefix to distinguish from triggers themselves
|
||||||
|
- Examples:
|
||||||
|
- `tf_audit_users` (function for t_audit_users)
|
||||||
|
- `tf_update_timestamp_posts` (function for t_update_timestamp_posts)
|
||||||
|
- `tf_validate_orders` (function for t_validate_orders)
|
||||||
|
|
||||||
|
## Multi-Column Objects
|
||||||
|
|
||||||
|
### Composite Primary Keys
|
||||||
|
|
||||||
|
**Pattern:** `pk_<schema>_<table>`
|
||||||
|
|
||||||
|
- Same as single-column PKs
|
||||||
|
- Example: `pk_public_order_items` (composite key on order_id + item_id)
|
||||||
|
|
||||||
|
### Composite Unique Constraints
|
||||||
|
|
||||||
|
**Pattern:** `uk_<table>_<column1>_<column2>_[...]`
|
||||||
|
|
||||||
|
- Append all column names in order
|
||||||
|
- Examples:
|
||||||
|
- `uk_users_email_domain` (UNIQUE(email, domain))
|
||||||
|
- `uk_inventory_warehouse_sku` (UNIQUE(warehouse_id, sku))
|
||||||
|
|
||||||
|
### Composite Unique Indexes
|
||||||
|
|
||||||
|
**Pattern:** `uidx_<table>_<column1>_<column2>_[...]`
|
||||||
|
|
||||||
|
- Append all column names in order
|
||||||
|
- Examples:
|
||||||
|
- `uidx_users_first_name_last_name` (UNIQUE INDEX on first_name, last_name)
|
||||||
|
- `uidx_sessions_user_id_device_id` (UNIQUE INDEX on user_id, device_id)
|
||||||
|
|
||||||
|
### Composite Regular Indexes
|
||||||
|
|
||||||
|
**Pattern:** `idx_<table>_<column1>_<column2>_[...]`
|
||||||
|
|
||||||
|
- Append all column names in order
|
||||||
|
- List columns in typical query filter order
|
||||||
|
- Examples:
|
||||||
|
- `idx_orders_user_id_created_at` (filter by user, then sort by created_at)
|
||||||
|
- `idx_logs_level_timestamp` (filter by level, then by timestamp)
|
||||||
|
|
||||||
|
## Special Cases & Conventions
|
||||||
|
|
||||||
|
### Audit Trail Tables
|
||||||
|
|
||||||
|
- Audit table naming: `<original_table>_audit` or `audit_<original_table>`
|
||||||
|
- Audit indexes follow standard pattern: `idx_<audit_table>_<column>`
|
||||||
|
- Examples:
|
||||||
|
- Users table audit: `users_audit` with `idx_users_audit_tablename`, `idx_users_audit_changedate`
|
||||||
|
- Posts table audit: `posts_audit` with `idx_posts_audit_tablename`, `idx_posts_audit_changedate`
|
||||||
|
|
||||||
|
### Temporal/Versioning Tables
|
||||||
|
|
||||||
|
- Use suffix `_history` or `_versions` if needed
|
||||||
|
- Apply standard naming rules with the full table name
|
||||||
|
- Examples:
|
||||||
|
- `idx_users_history_user_id`
|
||||||
|
- `uk_posts_versions_version_number`
|
||||||
|
|
||||||
|
### Schema-Specific Objects
|
||||||
|
|
||||||
|
- Always qualify with schema when needed: `pk_<schema>_<table>`
|
||||||
|
- Multiple schemas allowed: `pk_public_users`, `pk_staging_users`
|
||||||
|
|
||||||
|
### Reserved Words & Special Names
|
||||||
|
|
||||||
|
- Avoid PostgreSQL reserved keywords in object names
|
||||||
|
- If column/table names conflict, use quoted identifiers in DDL
|
||||||
|
- Naming convention rules still apply to the logical name
|
||||||
|
|
||||||
|
### Generated/Anonymous Indexes
|
||||||
|
|
||||||
|
- If an index lacks explicit naming, default to: `idx_<schema>_<table>`
|
||||||
|
- Should be replaced with explicit names following standards
|
||||||
|
- Examples (to be renamed):
|
||||||
|
- `idx_public_users` → should be `idx_users_<column>`
|
||||||
|
|
||||||
|
## Implementation Notes
|
||||||
|
|
||||||
|
### Code Generation
|
||||||
|
|
||||||
|
- Names are always lowercase in generated SQL
|
||||||
|
- Underscore separators are required
|
||||||
|
|
||||||
|
### Migration Safety
|
||||||
|
|
||||||
|
- Do NOT rename objects after creation without explicit migration
|
||||||
|
- Names should be consistent across all schema versions
|
||||||
|
- Test generated DDL against PostgreSQL before deployment
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
- Ensure consistency across all table and constraint generation
|
||||||
|
- Test with reserved words to verify escaping
|
||||||
|
|
||||||
|
## Related Documentation
|
||||||
|
|
||||||
|
- PostgreSQL Identifier Rules: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-IDENTIFIERS
|
||||||
|
- Constraint Documentation: https://www.postgresql.org/docs/current/ddl-constraints.html
|
||||||
|
- Index Documentation: https://www.postgresql.org/docs/current/indexes.html
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ type MigrationWriter struct {
|
|||||||
|
|
||||||
// NewMigrationWriter creates a new templated migration writer
|
// NewMigrationWriter creates a new templated migration writer
|
||||||
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
||||||
executor, err := NewTemplateExecutor()
|
executor, err := NewTemplateExecutor(options.FlattenSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
||||||
}
|
}
|
||||||
@@ -335,7 +336,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
ColumnType: modelCol.Type,
|
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
||||||
Default: defaultVal,
|
Default: defaultVal,
|
||||||
NotNull: modelCol.NotNull,
|
NotNull: modelCol.NotNull,
|
||||||
})
|
})
|
||||||
@@ -359,7 +360,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
NewType: modelCol.Type,
|
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -427,9 +428,11 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
|||||||
for _, modelTable := range model.Tables {
|
for _, modelTable := range model.Tables {
|
||||||
currentTable := currentTables[strings.ToLower(modelTable.Name)]
|
currentTable := currentTables[strings.ToLower(modelTable.Name)]
|
||||||
|
|
||||||
// Process primary keys first
|
// Process primary keys first - check explicit constraints
|
||||||
|
foundExplicitPK := false
|
||||||
for constraintName, constraint := range modelTable.Constraints {
|
for constraintName, constraint := range modelTable.Constraints {
|
||||||
if constraint.Type == models.PrimaryKeyConstraint {
|
if constraint.Type == models.PrimaryKeyConstraint {
|
||||||
|
foundExplicitPK = true
|
||||||
shouldCreate := true
|
shouldCreate := true
|
||||||
|
|
||||||
if currentTable != nil {
|
if currentTable != nil {
|
||||||
@@ -464,6 +467,53 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If no explicit PK constraint, check for columns with IsPrimaryKey = true
|
||||||
|
if !foundExplicitPK {
|
||||||
|
pkColumns := []string{}
|
||||||
|
for _, col := range modelTable.Columns {
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
pkColumns = append(pkColumns, col.SQLName())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(pkColumns) > 0 {
|
||||||
|
sort.Strings(pkColumns)
|
||||||
|
constraintName := fmt.Sprintf("pk_%s_%s", model.SQLName(), modelTable.SQLName())
|
||||||
|
shouldCreate := true
|
||||||
|
|
||||||
|
if currentTable != nil {
|
||||||
|
// Check if a PK constraint already exists (by any name)
|
||||||
|
for _, constraint := range currentTable.Constraints {
|
||||||
|
if constraint.Type == models.PrimaryKeyConstraint {
|
||||||
|
shouldCreate = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldCreate {
|
||||||
|
sql, err := w.executor.ExecuteCreatePrimaryKey(CreatePrimaryKeyData{
|
||||||
|
SchemaName: model.Name,
|
||||||
|
TableName: modelTable.Name,
|
||||||
|
ConstraintName: constraintName,
|
||||||
|
Columns: strings.Join(pkColumns, ", "),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
script := MigrationScript{
|
||||||
|
ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName),
|
||||||
|
ObjectType: "create primary key",
|
||||||
|
Schema: model.Name,
|
||||||
|
Priority: 160,
|
||||||
|
Sequence: len(scripts),
|
||||||
|
Body: sql,
|
||||||
|
}
|
||||||
|
scripts = append(scripts, script)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Process indexes
|
// Process indexes
|
||||||
for indexName, modelIndex := range modelTable.Indexes {
|
for indexName, modelIndex := range modelTable.Indexes {
|
||||||
// Skip primary key indexes
|
// Skip primary key indexes
|
||||||
@@ -703,7 +753,7 @@ func (w *MigrationWriter) generateAuditScripts(schema *models.Schema, auditConfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate audit function
|
// Generate audit function
|
||||||
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
|
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
|
||||||
funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
|
funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
|
||||||
|
|
||||||
funcSQL, err := w.executor.ExecuteAuditFunction(funcData)
|
funcSQL, err := w.executor.ExecuteAuditFunction(funcData)
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify audit function
|
// Verify audit function
|
||||||
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
|
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
|
||||||
t.Error("Migration missing audit function")
|
t.Error("Migration missing audit function")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,7 +137,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTemplateExecutor_CreateTable(t *testing.T) {
|
func TestTemplateExecutor_CreateTable(t *testing.T) {
|
||||||
executor, err := NewTemplateExecutor()
|
executor, err := NewTemplateExecutor(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create executor: %v", err)
|
t.Fatalf("Failed to create executor: %v", err)
|
||||||
}
|
}
|
||||||
@@ -170,14 +170,14 @@ func TestTemplateExecutor_CreateTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||||
executor, err := NewTemplateExecutor()
|
executor, err := NewTemplateExecutor(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create executor: %v", err)
|
t.Fatalf("Failed to create executor: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := AuditFunctionData{
|
data := AuditFunctionData{
|
||||||
SchemaName: "public",
|
SchemaName: "public",
|
||||||
FunctionName: "ft_audit_users",
|
FunctionName: "tf_audit_users",
|
||||||
TableName: "users",
|
TableName: "users",
|
||||||
TablePrefix: "NULL",
|
TablePrefix: "NULL",
|
||||||
PrimaryKey: "id",
|
PrimaryKey: "id",
|
||||||
@@ -202,7 +202,7 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
|||||||
|
|
||||||
t.Logf("Generated SQL:\n%s", sql)
|
t.Logf("Generated SQL:\n%s", sql)
|
||||||
|
|
||||||
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
|
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
|
||||||
t.Error("SQL missing function definition")
|
t.Error("SQL missing function definition")
|
||||||
}
|
}
|
||||||
if !strings.Contains(sql, "IF TG_OP = 'INSERT'") {
|
if !strings.Contains(sql, "IF TG_OP = 'INSERT'") {
|
||||||
@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
|||||||
t.Error("SQL missing DELETE handling")
|
t.Error("SQL missing DELETE handling")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
||||||
|
// Current database (empty)
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("entity")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
// Model database (with constraint starting with number)
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("entity")
|
||||||
|
|
||||||
|
// Create individual_actor_relationship table
|
||||||
|
table := models.InitTable("individual_actor_relationship", "entity")
|
||||||
|
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
|
||||||
|
actorIDCol.Type = "integer"
|
||||||
|
table.Columns["actor_id"] = actorIDCol
|
||||||
|
|
||||||
|
// Add constraint with name starting with number
|
||||||
|
constraint := &models.Constraint{
|
||||||
|
Name: "215162_fk_actor",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"actor_id"},
|
||||||
|
ReferencedSchema: "entity",
|
||||||
|
ReferencedTable: "actor",
|
||||||
|
ReferencedColumns: []string{"id"},
|
||||||
|
OnDelete: "CASCADE",
|
||||||
|
OnUpdate: "NO ACTION",
|
||||||
|
}
|
||||||
|
table.Constraints["215162_fk_actor"] = constraint
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
// Generate migration
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err = writer.WriteMigration(model, current)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated migration:\n%s", output)
|
||||||
|
|
||||||
|
// Verify constraint name is properly quoted
|
||||||
|
if !strings.Contains(output, `"215162_fk_actor"`) {
|
||||||
|
t.Error("Constraint name starting with number should be quoted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the SQL is syntactically correct (contains required keywords)
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT") {
|
||||||
|
t.Error("Migration missing ADD CONSTRAINT")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "FOREIGN KEY") {
|
||||||
|
t.Error("Migration missing FOREIGN KEY")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
|
|||||||
"quote": quote,
|
"quote": quote,
|
||||||
"escape": escape,
|
"escape": escape,
|
||||||
"safe_identifier": safeIdentifier,
|
"safe_identifier": safeIdentifier,
|
||||||
|
"quote_ident": quoteIdent,
|
||||||
|
|
||||||
// Type conversion
|
// Type conversion
|
||||||
"goTypeToSQL": goTypeToSQL,
|
"goTypeToSQL": goTypeToSQL,
|
||||||
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
|
|||||||
return strings.ToLower(safe)
|
return strings.ToLower(safe)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// quoteIdent quotes a PostgreSQL identifier if necessary
|
||||||
|
// Identifiers need quoting if they:
|
||||||
|
// - Start with a digit
|
||||||
|
// - Contain special characters
|
||||||
|
// - Are reserved keywords
|
||||||
|
// - Contain uppercase letters (to preserve case)
|
||||||
|
func quoteIdent(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return `""`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if quoting is needed
|
||||||
|
needsQuoting := unicode.IsDigit(rune(s[0]))
|
||||||
|
|
||||||
|
// Starts with digit
|
||||||
|
|
||||||
|
// Contains uppercase letters or special characters
|
||||||
|
for _, r := range s {
|
||||||
|
if unicode.IsUpper(r) {
|
||||||
|
needsQuoting = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
|
||||||
|
needsQuoting = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsQuoting {
|
||||||
|
// Escape double quotes by doubling them
|
||||||
|
escaped := strings.ReplaceAll(s, `"`, `""`)
|
||||||
|
return `"` + escaped + `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// Type conversion functions
|
// Type conversion functions
|
||||||
|
|
||||||
// goTypeToSQL converts Go type to PostgreSQL type
|
// goTypeToSQL converts Go type to PostgreSQL type
|
||||||
|
|||||||
@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQuoteIdent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"valid_name", "valid_name"},
|
||||||
|
{"ValidName", `"ValidName"`},
|
||||||
|
{"123column", `"123column"`},
|
||||||
|
{"215162_fk_constraint", `"215162_fk_constraint"`},
|
||||||
|
{"user-id", `"user-id"`},
|
||||||
|
{"user@domain", `"user@domain"`},
|
||||||
|
{`"quoted"`, `"""quoted"""`},
|
||||||
|
{"", `""`},
|
||||||
|
{"lowercase", "lowercase"},
|
||||||
|
{"with_underscore", "with_underscore"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := quoteIdent(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoTypeToSQL(t *testing.T) {
|
func TestGoTypeToSQL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
|
|||||||
// Check that all expected functions are registered
|
// Check that all expected functions are registered
|
||||||
expectedFuncs := []string{
|
expectedFuncs := []string{
|
||||||
"upper", "lower", "snake_case", "camelCase",
|
"upper", "lower", "snake_case", "camelCase",
|
||||||
"indent", "quote", "escape", "safe_identifier",
|
"indent", "quote", "escape", "safe_identifier", "quote_ident",
|
||||||
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
||||||
"first", "last", "filter", "mapFunc", "join_with",
|
"first", "last", "filter", "mapFunc", "join_with",
|
||||||
"join",
|
"join",
|
||||||
@@ -289,7 +314,7 @@ func TestFormatType(t *testing.T) {
|
|||||||
|
|
||||||
// Test that template functions work in actual templates
|
// Test that template functions work in actual templates
|
||||||
func TestTemplateFunctionsInTemplate(t *testing.T) {
|
func TestTemplateFunctionsInTemplate(t *testing.T) {
|
||||||
executor, err := NewTemplateExecutor()
|
executor, err := NewTemplateExecutor(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create executor: %v", err)
|
t.Fatalf("Failed to create executor: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,14 +18,39 @@ type TemplateExecutor struct {
|
|||||||
templates *template.Template
|
templates *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTemplateExecutor creates a new template executor
|
// NewTemplateExecutor creates a new template executor.
|
||||||
func NewTemplateExecutor() (*TemplateExecutor, error) {
|
// flattenSchema controls whether schema.table identifiers use dot or underscore separation.
|
||||||
|
func NewTemplateExecutor(flattenSchema bool) (*TemplateExecutor, error) {
|
||||||
// Create template with custom functions
|
// Create template with custom functions
|
||||||
funcMap := make(template.FuncMap)
|
funcMap := make(template.FuncMap)
|
||||||
for k, v := range TemplateFunctions() {
|
for k, v := range TemplateFunctions() {
|
||||||
funcMap[k] = v
|
funcMap[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// qual_table returns a quoted, schema-qualified identifier.
|
||||||
|
// With flatten=false: "schema"."table" (or unquoted equivalents).
|
||||||
|
// With flatten=true: "schema_table".
|
||||||
|
funcMap["qual_table"] = func(schema, name string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return quoteIdent(name)
|
||||||
|
}
|
||||||
|
if flattenSchema {
|
||||||
|
return quoteIdent(schema + "_" + name)
|
||||||
|
}
|
||||||
|
return quoteIdent(schema) + "." + quoteIdent(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// qual_table_raw is the same as qual_table but without identifier quoting.
|
||||||
|
funcMap["qual_table_raw"] = func(schema, name string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if flattenSchema {
|
||||||
|
return schema + "_" + name
|
||||||
|
}
|
||||||
|
return schema + "." + name
|
||||||
|
}
|
||||||
|
|
||||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS, "templates/*.tmpl")
|
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS, "templates/*.tmpl")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse templates: %w", err)
|
return nil, fmt.Errorf("failed to parse templates: %w", err)
|
||||||
@@ -177,6 +202,72 @@ type AuditTriggerData struct {
|
|||||||
Events string
|
Events string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateUniqueConstraintData contains data for create unique constraint template
|
||||||
|
type CreateUniqueConstraintData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
Columns string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateCheckConstraintData contains data for create check constraint template
|
||||||
|
type CreateCheckConstraintData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
Expression string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
|
||||||
|
type CreateForeignKeyWithCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
SourceColumns string
|
||||||
|
TargetSchema string
|
||||||
|
TargetTable string
|
||||||
|
TargetColumns string
|
||||||
|
OnDelete string
|
||||||
|
OnUpdate string
|
||||||
|
Deferrable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSequenceValueData contains data for set sequence value template
|
||||||
|
type SetSequenceValueData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
SequenceName string
|
||||||
|
ColumnName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSequenceData contains data for create sequence template
|
||||||
|
type CreateSequenceData struct {
|
||||||
|
SchemaName string
|
||||||
|
SequenceName string
|
||||||
|
Increment int
|
||||||
|
MinValue int64
|
||||||
|
MaxValue int64
|
||||||
|
StartValue int64
|
||||||
|
CacheSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddColumnWithCheckData contains data for add column with existence check template
|
||||||
|
type AddColumnWithCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
ColumnDefinition string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
|
||||||
|
type CreatePrimaryKeyWithAutoGenCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||||
|
Columns string
|
||||||
|
}
|
||||||
|
|
||||||
// Execute methods for each template
|
// Execute methods for each template
|
||||||
|
|
||||||
// ExecuteCreateTable executes the create table template
|
// ExecuteCreateTable executes the create table template
|
||||||
@@ -319,6 +410,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
|
|||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateUniqueConstraint executes the create unique constraint template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateCheckConstraint executes the create check constraint template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteSetSequenceValue executes the set sequence value template
|
||||||
|
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateSequence executes the create sequence template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteAddColumnWithCheck executes the add column with check template
|
||||||
|
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions to build template data from models
|
// Helper functions to build template data from models
|
||||||
|
|
||||||
// BuildCreateTableData builds CreateTableData from a models.Table
|
// BuildCreateTableData builds CreateTableData from a models.Table
|
||||||
@@ -355,7 +516,7 @@ func BuildAuditFunctionData(
|
|||||||
auditSchema string,
|
auditSchema string,
|
||||||
userFunction string,
|
userFunction string,
|
||||||
) AuditFunctionData {
|
) AuditFunctionData {
|
||||||
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
|
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
|
||||||
|
|
||||||
// Build list of audited columns
|
// Build list of audited columns
|
||||||
auditedColumns := make([]*models.Column, 0)
|
auditedColumns := make([]*models.Column, 0)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
|
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
|
||||||
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
||||||
{{- if .NotNull}} NOT NULL{{end}};
|
{{- if .NotNull}} NOT NULL{{end}};
|
||||||
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND column_name = '{{.ColumnName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
{{- if .SetDefault -}}
|
{{- if .SetDefault -}}
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
|
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
|
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
|
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
CREATE OR REPLACE FUNCTION {{.SchemaName}}.{{.FunctionName}}()
|
CREATE OR REPLACE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}()
|
||||||
RETURNS trigger AS
|
RETURNS trigger AS
|
||||||
$body$
|
$body$
|
||||||
DECLARE
|
DECLARE
|
||||||
@@ -81,4 +81,4 @@ LANGUAGE plpgsql
|
|||||||
VOLATILE
|
VOLATILE
|
||||||
SECURITY DEFINER;
|
SECURITY DEFINER;
|
||||||
|
|
||||||
COMMENT ON FUNCTION {{.SchemaName}}.{{.FunctionName}}() IS 'Audit trigger function for table {{.SchemaName}}.{{.TableName}}';
|
COMMENT ON FUNCTION {{qual_table_raw .SchemaName .FunctionName}}() IS 'Audit trigger function for table {{qual_table_raw .SchemaName .TableName}}';
|
||||||
@@ -4,13 +4,13 @@ BEGIN
|
|||||||
SELECT 1
|
SELECT 1
|
||||||
FROM pg_trigger
|
FROM pg_trigger
|
||||||
WHERE tgname = '{{.TriggerName}}'
|
WHERE tgname = '{{.TriggerName}}'
|
||||||
AND tgrelid = '{{.SchemaName}}.{{.TableName}}'::regclass
|
AND tgrelid = '{{qual_table_raw .SchemaName .TableName}}'::regclass
|
||||||
) THEN
|
) THEN
|
||||||
CREATE TRIGGER {{.TriggerName}}
|
CREATE TRIGGER {{.TriggerName}}
|
||||||
AFTER {{.Events}}
|
AFTER {{.Events}}
|
||||||
ON {{.SchemaName}}.{{.TableName}}
|
ON {{qual_table_raw .SchemaName .TableName}}
|
||||||
FOR EACH ROW
|
FOR EACH ROW
|
||||||
EXECUTE FUNCTION {{.SchemaName}}.{{.FunctionName}}();
|
EXECUTE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}();
|
||||||
END IF;
|
END IF;
|
||||||
END;
|
END;
|
||||||
$$;
|
$$;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
{{/* Base constraint template */}}
|
{{/* Base constraint template */}}
|
||||||
{{- define "constraint_base" -}}
|
{{- define "constraint_base" -}}
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||||
ADD CONSTRAINT {{.ConstraintName}}
|
ADD CONSTRAINT {{.ConstraintName}}
|
||||||
{{block "constraint_definition" .}}{{end}};
|
{{block "constraint_definition" .}}{{end}};
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
@@ -15,7 +15,7 @@ BEGIN
|
|||||||
AND table_name = '{{.TableName}}'
|
AND table_name = '{{.TableName}}'
|
||||||
AND constraint_name = '{{.ConstraintName}}'
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
) THEN
|
) THEN
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||||
DROP CONSTRAINT {{.ConstraintName}};
|
DROP CONSTRAINT {{.ConstraintName}};
|
||||||
END IF;
|
END IF;
|
||||||
END;
|
END;
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
{{/* Base ALTER TABLE structure */}}
|
{{/* Base ALTER TABLE structure */}}
|
||||||
{{- define "alter_table_base" -}}
|
{{- define "alter_table_base" -}}
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||||
{{block "alter_operation" .}}{{end}};
|
{{block "alter_operation" .}}{{end}};
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
|
|
||||||
@@ -30,5 +30,5 @@ $$;
|
|||||||
|
|
||||||
{{/* Common drop pattern */}}
|
{{/* Common drop pattern */}}
|
||||||
{{- define "drop_if_exists" -}}
|
{{- define "drop_if_exists" -}}
|
||||||
{{block "drop_type" .}}{{end}} IF EXISTS {{.SchemaName}}.{{.ObjectName}};
|
{{block "drop_type" .}}{{end}} IF EXISTS {{qual_table_raw .SchemaName .ObjectName}};
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
@@ -1 +1 @@
|
|||||||
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
|
COMMENT ON COLUMN {{qual_table .SchemaName .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||||
@@ -1 +1 @@
|
|||||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
|
COMMENT ON TABLE {{qual_table .SchemaName .TableName}} IS '{{.Comment}}';
|
||||||
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||||
|
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ADD CONSTRAINT {{.ConstraintName}}
|
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||||
FOREIGN KEY ({{.SourceColumns}})
|
FOREIGN KEY ({{.SourceColumns}})
|
||||||
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
|
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
|
||||||
ON DELETE {{.OnDelete}}
|
ON DELETE {{.OnDelete}}
|
||||||
ON UPDATE {{.OnUpdate}}
|
ON UPDATE {{.OnUpdate}}
|
||||||
DEFERRABLE;
|
DEFERRABLE;
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
|
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||||
|
FOREIGN KEY ({{.SourceColumns}})
|
||||||
|
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
|
||||||
|
ON DELETE {{.OnDelete}}
|
||||||
|
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
|
||||||
|
DEFERRABLE{{end}};
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
|
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
|
||||||
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
|
ON {{qual_table .SchemaName .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||||
@@ -6,8 +6,8 @@ BEGIN
|
|||||||
AND table_name = '{{.TableName}}'
|
AND table_name = '{{.TableName}}'
|
||||||
AND constraint_name = '{{.ConstraintName}}'
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
) THEN
|
) THEN
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||||
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
|
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||||
END IF;
|
END IF;
|
||||||
END;
|
END;
|
||||||
$$;
|
$$;
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
auto_pk_name text;
|
||||||
|
BEGIN
|
||||||
|
-- Drop auto-generated primary key if it exists
|
||||||
|
SELECT constraint_name INTO auto_pk_name
|
||||||
|
FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_type = 'PRIMARY KEY'
|
||||||
|
AND constraint_name IN ({{.AutoGenNames}});
|
||||||
|
|
||||||
|
IF auto_pk_name IS NOT NULL THEN
|
||||||
|
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Add named primary key if it doesn't exist
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
CREATE SEQUENCE IF NOT EXISTS {{qual_table .SchemaName .SequenceName}}
|
||||||
|
INCREMENT {{.Increment}}
|
||||||
|
MINVALUE {{.MinValue}}
|
||||||
|
MAXVALUE {{.MaxValue}}
|
||||||
|
START {{.StartValue}}
|
||||||
|
CACHE {{.CacheSize}};
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
CREATE TABLE IF NOT EXISTS {{qual_table .SchemaName .TableName}} (
|
||||||
{{- range $i, $col := .Columns}}
|
{{- range $i, $col := .Columns}}
|
||||||
{{- if $i}},{{end}}
|
{{- if $i}},{{end}}
|
||||||
{{$col.Name}} {{$col.Type}}
|
{{quote_ident $col.Name}} {{$col.Type}}
|
||||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|||||||
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1 +1 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||||
@@ -1 +1 @@
|
|||||||
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
|
DROP INDEX IF EXISTS {{qual_table .SchemaName .IndexName}} CASCADE;
|
||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
{{/* Qualified table name */}}
|
{{/* Qualified table name */}}
|
||||||
{{- define "qualified_table" -}}
|
{{- define "qualified_table" -}}
|
||||||
{{.SchemaName}}.{{.TableName}}
|
{{qual_table_raw .SchemaName .TableName}}
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
|
|
||||||
{{/* Index method clause */}}
|
{{/* Index method clause */}}
|
||||||
|
|||||||
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
m_cnt bigint;
|
||||||
|
BEGIN
|
||||||
|
IF EXISTS (
|
||||||
|
SELECT 1 FROM pg_class c
|
||||||
|
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||||
|
WHERE c.relname = '{{.SequenceName}}'
|
||||||
|
AND n.nspname = '{{.SchemaName}}'
|
||||||
|
AND c.relkind = 'S'
|
||||||
|
) THEN
|
||||||
|
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
|
||||||
|
FROM {{qual_table .SchemaName .TableName}}
|
||||||
|
INTO m_cnt;
|
||||||
|
|
||||||
|
PERFORM setval('{{qual_table_raw .SchemaName .SequenceName}}'::regclass, m_cnt);
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -45,11 +45,11 @@ func TestWriteDatabase(t *testing.T) {
|
|||||||
|
|
||||||
// Add unique index
|
// Add unique index
|
||||||
uniqueEmailIndex := &models.Index{
|
uniqueEmailIndex := &models.Index{
|
||||||
Name: "uk_users_email",
|
Name: "uidx_users_email",
|
||||||
Unique: true,
|
Unique: true,
|
||||||
Columns: []string{"email"},
|
Columns: []string{"email"},
|
||||||
}
|
}
|
||||||
table.Indexes["uk_users_email"] = uniqueEmailIndex
|
table.Indexes["uidx_users_email"] = uniqueEmailIndex
|
||||||
|
|
||||||
schema.Tables = append(schema.Tables, table)
|
schema.Tables = append(schema.Tables, table)
|
||||||
db.Schemas = append(db.Schemas, schema)
|
db.Schemas = append(db.Schemas, schema)
|
||||||
@@ -164,6 +164,296 @@ func TestWriteForeignKeys(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteUniqueConstraints(t *testing.T) {
|
||||||
|
// Create a test database with unique constraints
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create table with unique constraints
|
||||||
|
table := models.InitTable("users", "public")
|
||||||
|
|
||||||
|
// Add columns
|
||||||
|
emailCol := models.InitColumn("email", "users", "public")
|
||||||
|
emailCol.Type = "varchar(255)"
|
||||||
|
emailCol.NotNull = true
|
||||||
|
table.Columns["email"] = emailCol
|
||||||
|
|
||||||
|
guidCol := models.InitColumn("guid", "users", "public")
|
||||||
|
guidCol.Type = "uuid"
|
||||||
|
guidCol.NotNull = true
|
||||||
|
table.Columns["guid"] = guidCol
|
||||||
|
|
||||||
|
// Add unique constraints
|
||||||
|
emailConstraint := &models.Constraint{
|
||||||
|
Name: "uq_email",
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "users",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
}
|
||||||
|
table.Constraints["uq_email"] = emailConstraint
|
||||||
|
|
||||||
|
guidConstraint := &models.Constraint{
|
||||||
|
Name: "uq_guid",
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "users",
|
||||||
|
Columns: []string{"guid"},
|
||||||
|
}
|
||||||
|
table.Constraints["uq_guid"] = guidConstraint
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify unique constraints are present
|
||||||
|
if !strings.Contains(output, "-- Unique constraints for schema: public") {
|
||||||
|
t.Errorf("Output missing unique constraints header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
|
||||||
|
t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
|
||||||
|
t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCheckConstraints(t *testing.T) {
|
||||||
|
// Create a test database with check constraints
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create table with check constraints
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
// Add columns
|
||||||
|
priceCol := models.InitColumn("price", "products", "public")
|
||||||
|
priceCol.Type = "numeric(10,2)"
|
||||||
|
table.Columns["price"] = priceCol
|
||||||
|
|
||||||
|
statusCol := models.InitColumn("status", "products", "public")
|
||||||
|
statusCol.Type = "varchar(20)"
|
||||||
|
table.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
quantityCol := models.InitColumn("quantity", "products", "public")
|
||||||
|
quantityCol.Type = "integer"
|
||||||
|
table.Columns["quantity"] = quantityCol
|
||||||
|
|
||||||
|
// Add check constraints
|
||||||
|
priceConstraint := &models.Constraint{
|
||||||
|
Name: "ck_price_positive",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "products",
|
||||||
|
Expression: "price >= 0",
|
||||||
|
}
|
||||||
|
table.Constraints["ck_price_positive"] = priceConstraint
|
||||||
|
|
||||||
|
statusConstraint := &models.Constraint{
|
||||||
|
Name: "ck_status_valid",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "products",
|
||||||
|
Expression: "status IN ('active', 'inactive', 'discontinued')",
|
||||||
|
}
|
||||||
|
table.Constraints["ck_status_valid"] = statusConstraint
|
||||||
|
|
||||||
|
quantityConstraint := &models.Constraint{
|
||||||
|
Name: "ck_quantity_nonnegative",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "products",
|
||||||
|
Expression: "quantity >= 0",
|
||||||
|
}
|
||||||
|
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify check constraints are present
|
||||||
|
if !strings.Contains(output, "-- Check constraints for schema: public") {
|
||||||
|
t.Errorf("Output missing check constraints header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
|
||||||
|
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
|
||||||
|
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
|
||||||
|
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAllConstraintTypes(t *testing.T) {
|
||||||
|
// Create a comprehensive test with all constraint types
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create orders table
|
||||||
|
ordersTable := models.InitTable("orders", "public")
|
||||||
|
|
||||||
|
// Add columns
|
||||||
|
idCol := models.InitColumn("id", "orders", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
ordersTable.Columns["id"] = idCol
|
||||||
|
|
||||||
|
userIDCol := models.InitColumn("user_id", "orders", "public")
|
||||||
|
userIDCol.Type = "integer"
|
||||||
|
userIDCol.NotNull = true
|
||||||
|
ordersTable.Columns["user_id"] = userIDCol
|
||||||
|
|
||||||
|
orderNumberCol := models.InitColumn("order_number", "orders", "public")
|
||||||
|
orderNumberCol.Type = "varchar(50)"
|
||||||
|
orderNumberCol.NotNull = true
|
||||||
|
ordersTable.Columns["order_number"] = orderNumberCol
|
||||||
|
|
||||||
|
totalCol := models.InitColumn("total", "orders", "public")
|
||||||
|
totalCol.Type = "numeric(10,2)"
|
||||||
|
ordersTable.Columns["total"] = totalCol
|
||||||
|
|
||||||
|
statusCol := models.InitColumn("status", "orders", "public")
|
||||||
|
statusCol.Type = "varchar(20)"
|
||||||
|
ordersTable.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
// Add primary key constraint
|
||||||
|
pkConstraint := &models.Constraint{
|
||||||
|
Name: "pk_orders",
|
||||||
|
Type: models.PrimaryKeyConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Columns: []string{"id"},
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["pk_orders"] = pkConstraint
|
||||||
|
|
||||||
|
// Add unique constraint
|
||||||
|
uniqueConstraint := &models.Constraint{
|
||||||
|
Name: "uq_order_number",
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Columns: []string{"order_number"},
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
|
||||||
|
|
||||||
|
// Add check constraint
|
||||||
|
checkConstraint := &models.Constraint{
|
||||||
|
Name: "ck_total_positive",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Expression: "total > 0",
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["ck_total_positive"] = checkConstraint
|
||||||
|
|
||||||
|
statusCheckConstraint := &models.Constraint{
|
||||||
|
Name: "ck_status_valid",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Expression: "status IN ('pending', 'completed', 'cancelled')",
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
|
||||||
|
|
||||||
|
// Add foreign key constraint (referencing a users table)
|
||||||
|
fkConstraint := &models.Constraint{
|
||||||
|
Name: "fk_orders_user",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Columns: []string{"user_id"},
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
ReferencedTable: "users",
|
||||||
|
ReferencedColumns: []string{"id"},
|
||||||
|
OnDelete: "CASCADE",
|
||||||
|
OnUpdate: "CASCADE",
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["fk_orders_user"] = fkConstraint
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, ordersTable)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify all constraint types are present
|
||||||
|
expectedConstraints := map[string]string{
|
||||||
|
"Primary Key": "PRIMARY KEY",
|
||||||
|
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||||
|
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||||
|
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||||
|
"Foreign Key": "FOREIGN KEY",
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, expected := range expectedConstraints {
|
||||||
|
if !strings.Contains(output, expected) {
|
||||||
|
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify section headers
|
||||||
|
sections := []string{
|
||||||
|
"-- Primary keys for schema: public",
|
||||||
|
"-- Unique constraints for schema: public",
|
||||||
|
"-- Check constraints for schema: public",
|
||||||
|
"-- Foreign keys for schema: public",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, section := range sections {
|
||||||
|
if !strings.Contains(output, section) {
|
||||||
|
t.Errorf("Output missing section header: %s", section)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteTable(t *testing.T) {
|
func TestWriteTable(t *testing.T) {
|
||||||
// Create a single table
|
// Create a single table
|
||||||
table := models.InitTable("products", "public")
|
table := models.InitTable("products", "public")
|
||||||
@@ -241,3 +531,327 @@ func TestIsIntegerType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTypeConversion(t *testing.T) {
|
||||||
|
// Test that invalid Go types are converted to valid PostgreSQL types
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a test table with Go types instead of SQL types
|
||||||
|
table := models.InitTable("test_types", "public")
|
||||||
|
|
||||||
|
// Add columns with Go types (invalid for PostgreSQL)
|
||||||
|
stringCol := models.InitColumn("name", "test_types", "public")
|
||||||
|
stringCol.Type = "string" // Should be converted to "text"
|
||||||
|
table.Columns["name"] = stringCol
|
||||||
|
|
||||||
|
int64Col := models.InitColumn("big_id", "test_types", "public")
|
||||||
|
int64Col.Type = "int64" // Should be converted to "bigint"
|
||||||
|
table.Columns["big_id"] = int64Col
|
||||||
|
|
||||||
|
int16Col := models.InitColumn("small_id", "test_types", "public")
|
||||||
|
int16Col.Type = "int16" // Should be converted to "smallint"
|
||||||
|
table.Columns["small_id"] = int16Col
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify that Go types were converted to PostgreSQL types
|
||||||
|
if strings.Contains(output, "string") {
|
||||||
|
t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "int64") {
|
||||||
|
t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "int16") {
|
||||||
|
t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify correct PostgreSQL types are present
|
||||||
|
if !strings.Contains(output, "text") {
|
||||||
|
t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "bigint") {
|
||||||
|
t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "smallint") {
|
||||||
|
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "products", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "products", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify our naming convention is used
|
||||||
|
if !strings.Contains(output, "pk_public_products") {
|
||||||
|
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it drops auto-generated primary keys
|
||||||
|
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
|
||||||
|
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it checks for our specific named constraint before adding it
|
||||||
|
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
||||||
|
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumnSizeSpecifiers(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("test_sizes", "public")
|
||||||
|
|
||||||
|
// Integer with invalid size specifier - should ignore size
|
||||||
|
integerCol := models.InitColumn("int_col", "test_sizes", "public")
|
||||||
|
integerCol.Type = "integer"
|
||||||
|
integerCol.Length = 32
|
||||||
|
table.Columns["int_col"] = integerCol
|
||||||
|
|
||||||
|
// Bigint with invalid size specifier - should ignore size
|
||||||
|
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
|
||||||
|
bigintCol.Type = "bigint"
|
||||||
|
bigintCol.Length = 64
|
||||||
|
table.Columns["bigint_col"] = bigintCol
|
||||||
|
|
||||||
|
// Smallint with invalid size specifier - should ignore size
|
||||||
|
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
|
||||||
|
smallintCol.Type = "smallint"
|
||||||
|
smallintCol.Length = 16
|
||||||
|
table.Columns["smallint_col"] = smallintCol
|
||||||
|
|
||||||
|
// Text with length - should convert to varchar
|
||||||
|
textCol := models.InitColumn("text_col", "test_sizes", "public")
|
||||||
|
textCol.Type = "text"
|
||||||
|
textCol.Length = 100
|
||||||
|
table.Columns["text_col"] = textCol
|
||||||
|
|
||||||
|
// Varchar with length - should keep varchar with length
|
||||||
|
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
|
||||||
|
varcharCol.Type = "varchar"
|
||||||
|
varcharCol.Length = 50
|
||||||
|
table.Columns["varchar_col"] = varcharCol
|
||||||
|
|
||||||
|
// Decimal with precision and scale - should keep them
|
||||||
|
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
|
||||||
|
decimalCol.Type = "decimal"
|
||||||
|
decimalCol.Precision = 19
|
||||||
|
decimalCol.Scale = 4
|
||||||
|
table.Columns["decimal_col"] = decimalCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify invalid size specifiers are NOT present
|
||||||
|
invalidPatterns := []string{
|
||||||
|
"integer(32)",
|
||||||
|
"bigint(64)",
|
||||||
|
"smallint(16)",
|
||||||
|
"text(100)",
|
||||||
|
}
|
||||||
|
for _, pattern := range invalidPatterns {
|
||||||
|
if strings.Contains(output, pattern) {
|
||||||
|
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify valid patterns ARE present
|
||||||
|
validPatterns := []string{
|
||||||
|
"integer", // without size
|
||||||
|
"bigint", // without size
|
||||||
|
"smallint", // without size
|
||||||
|
"varchar(100)", // text converted to varchar with length
|
||||||
|
"varchar(50)", // varchar with length
|
||||||
|
"decimal(19,4)", // decimal with precision and scale
|
||||||
|
}
|
||||||
|
for _, pattern := range validPatterns {
|
||||||
|
if !strings.Contains(output, pattern) {
|
||||||
|
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAddColumnStatements(t *testing.T) {
|
||||||
|
// Create a test database with tables that have new columns
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a table with columns
|
||||||
|
table := models.InitTable("users", "public")
|
||||||
|
|
||||||
|
// Existing column
|
||||||
|
idCol := models.InitColumn("id", "users", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.NotNull = true
|
||||||
|
idCol.Sequence = 1
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
// New column to be added
|
||||||
|
emailCol := models.InitColumn("email", "users", "public")
|
||||||
|
emailCol.Type = "varchar"
|
||||||
|
emailCol.Length = 255
|
||||||
|
emailCol.NotNull = true
|
||||||
|
emailCol.Sequence = 2
|
||||||
|
table.Columns["email"] = emailCol
|
||||||
|
|
||||||
|
// New column with default
|
||||||
|
statusCol := models.InitColumn("status", "users", "public")
|
||||||
|
statusCol.Type = "text"
|
||||||
|
statusCol.Default = "active"
|
||||||
|
statusCol.Sequence = 3
|
||||||
|
table.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
|
||||||
|
// Generate ADD COLUMN statements
|
||||||
|
statements, err := writer.GenerateAddColumnsForDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all statements to verify content
|
||||||
|
output := strings.Join(statements, "\n")
|
||||||
|
t.Logf("Generated ADD COLUMN statements:\n%s", output)
|
||||||
|
|
||||||
|
// Verify expected elements
|
||||||
|
expectedStrings := []string{
|
||||||
|
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
|
||||||
|
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
|
||||||
|
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
|
||||||
|
"information_schema.columns",
|
||||||
|
"table_schema = 'public'",
|
||||||
|
"table_name = 'users'",
|
||||||
|
"column_name = 'id'",
|
||||||
|
"column_name = 'email'",
|
||||||
|
"column_name = 'status'",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectedStrings {
|
||||||
|
if !strings.Contains(output, expected) {
|
||||||
|
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify DO blocks are present for conditional adds
|
||||||
|
doBlockCount := strings.Count(output, "DO $$")
|
||||||
|
if doBlockCount < 3 {
|
||||||
|
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify IF NOT EXISTS logic
|
||||||
|
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
|
||||||
|
if ifNotExistsCount < 3 {
|
||||||
|
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAddColumnStatements(t *testing.T) {
|
||||||
|
// Create a test database
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a table with a new column to be added
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "products", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
// New column with various properties
|
||||||
|
descCol := models.InitColumn("description", "products", "public")
|
||||||
|
descCol.Type = "text"
|
||||||
|
descCol.NotNull = false
|
||||||
|
table.Columns["description"] = descCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write ADD COLUMN statements
|
||||||
|
err := writer.WriteAddColumnStatements(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteAddColumnStatements failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated output:\n%s", output)
|
||||||
|
|
||||||
|
// Verify output contains expected elements
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
|
||||||
|
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
|
||||||
|
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "DO $$") {
|
||||||
|
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ The SQL Executor Writer (`sqlexec`) executes SQL scripts from `models.Script` ob
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Ordered Execution**: Scripts execute in Priority→Sequence order
|
- **Ordered Execution**: Scripts execute in Priority→Sequence→Name order
|
||||||
- **PostgreSQL Support**: Uses `pgx/v5` driver for robust PostgreSQL connectivity
|
- **PostgreSQL Support**: Uses `pgx/v5` driver for robust PostgreSQL connectivity
|
||||||
- **Stop on Error**: Execution halts immediately on first error (default behavior)
|
- **Stop on Error**: Execution halts immediately on first error (default behavior)
|
||||||
- **Progress Reporting**: Prints execution status to stdout
|
- **Progress Reporting**: Prints execution status to stdout
|
||||||
@@ -103,19 +103,40 @@ Scripts are sorted and executed based on:
|
|||||||
|
|
||||||
1. **Priority** (ascending): Lower priority values execute first
|
1. **Priority** (ascending): Lower priority values execute first
|
||||||
2. **Sequence** (ascending): Within same priority, lower sequence values execute first
|
2. **Sequence** (ascending): Within same priority, lower sequence values execute first
|
||||||
|
3. **Name** (ascending): Within same priority and sequence, alphabetical order by name
|
||||||
|
|
||||||
### Example Execution Order
|
### Example Execution Order
|
||||||
|
|
||||||
Given these scripts:
|
Given these scripts:
|
||||||
```
|
```
|
||||||
Script A: Priority=2, Sequence=1
|
Script A: Priority=2, Sequence=1, Name="zebra"
|
||||||
Script B: Priority=1, Sequence=3
|
Script B: Priority=1, Sequence=3, Name="script"
|
||||||
Script C: Priority=1, Sequence=1
|
Script C: Priority=1, Sequence=1, Name="apple"
|
||||||
Script D: Priority=1, Sequence=2
|
Script D: Priority=1, Sequence=1, Name="beta"
|
||||||
Script E: Priority=3, Sequence=1
|
Script E: Priority=3, Sequence=1, Name="script"
|
||||||
```
|
```
|
||||||
|
|
||||||
Execution order: **C → D → B → A → E**
|
Execution order: **C (apple) → D (beta) → B → A → E**
|
||||||
|
|
||||||
|
### Directory-based Sorting Example
|
||||||
|
|
||||||
|
Given these files:
|
||||||
|
```
|
||||||
|
1_001_create_schema.sql
|
||||||
|
1_001_create_users.sql ← Alphabetically before "drop_tables"
|
||||||
|
1_001_drop_tables.sql
|
||||||
|
1_002_add_indexes.sql
|
||||||
|
2_001_constraints.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
Execution order (note alphabetical sorting at same priority/sequence):
|
||||||
|
```
|
||||||
|
1_001_create_schema.sql
|
||||||
|
1_001_create_users.sql
|
||||||
|
1_001_drop_tables.sql
|
||||||
|
1_002_add_indexes.sql
|
||||||
|
2_001_constraints.sql
|
||||||
|
```
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,11 @@ func NewWriter(options *writers.WriterOptions) *Writer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options returns the writer options (useful for reading execution results)
|
||||||
|
func (w *Writer) Options() *writers.WriterOptions {
|
||||||
|
return w.options
|
||||||
|
}
|
||||||
|
|
||||||
// WriteDatabase executes all scripts from all schemas in the database
|
// WriteDatabase executes all scripts from all schemas in the database
|
||||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
@@ -86,20 +91,39 @@ func (w *Writer) WriteTable(table *models.Table) error {
|
|||||||
return fmt.Errorf("WriteTable is not supported for SQL script execution")
|
return fmt.Errorf("WriteTable is not supported for SQL script execution")
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeScripts executes scripts in Priority then Sequence order
|
// executeScripts executes scripts in Priority, Sequence, then Name order
|
||||||
func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*models.Script) error {
|
func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*models.Script) error {
|
||||||
if len(scripts) == 0 {
|
if len(scripts) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort scripts by Priority (ascending) then Sequence (ascending)
|
// Check if we should ignore errors
|
||||||
|
ignoreErrors := false
|
||||||
|
if val, ok := w.options.Metadata["ignore_errors"].(bool); ok {
|
||||||
|
ignoreErrors = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track failed scripts and execution counts
|
||||||
|
var failedScripts []struct {
|
||||||
|
name string
|
||||||
|
priority int
|
||||||
|
sequence uint
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
successCount := 0
|
||||||
|
totalCount := 0
|
||||||
|
|
||||||
|
// Sort scripts by Priority (ascending), Sequence (ascending), then Name (ascending)
|
||||||
sortedScripts := make([]*models.Script, len(scripts))
|
sortedScripts := make([]*models.Script, len(scripts))
|
||||||
copy(sortedScripts, scripts)
|
copy(sortedScripts, scripts)
|
||||||
sort.Slice(sortedScripts, func(i, j int) bool {
|
sort.Slice(sortedScripts, func(i, j int) bool {
|
||||||
if sortedScripts[i].Priority != sortedScripts[j].Priority {
|
if sortedScripts[i].Priority != sortedScripts[j].Priority {
|
||||||
return sortedScripts[i].Priority < sortedScripts[j].Priority
|
return sortedScripts[i].Priority < sortedScripts[j].Priority
|
||||||
}
|
}
|
||||||
return sortedScripts[i].Sequence < sortedScripts[j].Sequence
|
if sortedScripts[i].Sequence != sortedScripts[j].Sequence {
|
||||||
|
return sortedScripts[i].Sequence < sortedScripts[j].Sequence
|
||||||
|
}
|
||||||
|
return sortedScripts[i].Name < sortedScripts[j].Name
|
||||||
})
|
})
|
||||||
|
|
||||||
// Execute each script in order
|
// Execute each script in order
|
||||||
@@ -108,18 +132,49 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalCount++
|
||||||
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
|
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
|
||||||
script.Name, script.Priority, script.Sequence)
|
script.Name, script.Priority, script.Sequence)
|
||||||
|
|
||||||
// Execute the SQL script
|
// Execute the SQL script
|
||||||
_, err := conn.Exec(ctx, script.SQL)
|
_, err := conn.Exec(ctx, script.SQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute script %s (Priority=%d, Sequence=%d): %w",
|
if ignoreErrors {
|
||||||
|
fmt.Printf("⚠ Error executing %s: %v (continuing due to --ignore-errors)\n", script.Name, err)
|
||||||
|
failedScripts = append(failedScripts, struct {
|
||||||
|
name string
|
||||||
|
priority int
|
||||||
|
sequence uint
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
name: script.Name,
|
||||||
|
priority: script.Priority,
|
||||||
|
sequence: script.Sequence,
|
||||||
|
err: err,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("script %s (Priority=%d, Sequence=%d): %w",
|
||||||
script.Name, script.Priority, script.Sequence, err)
|
script.Name, script.Priority, script.Sequence, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
successCount++
|
||||||
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
|
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store execution results in metadata for caller
|
||||||
|
w.options.Metadata["execution_total"] = totalCount
|
||||||
|
w.options.Metadata["execution_success"] = successCount
|
||||||
|
w.options.Metadata["execution_failed"] = len(failedScripts)
|
||||||
|
|
||||||
|
// Print summary of failed scripts if any
|
||||||
|
if len(failedScripts) > 0 {
|
||||||
|
fmt.Printf("\n⚠ Failed Scripts Summary (%d failed):\n", len(failedScripts))
|
||||||
|
for i, failed := range failedScripts {
|
||||||
|
fmt.Printf(" %d. %s (Priority=%d, Sequence=%d)\n Error: %v\n",
|
||||||
|
i+1, failed.name, failed.priority, failed.sequence, failed.err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,13 +99,13 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestScriptSorting verifies that scripts are sorted correctly by Priority then Sequence
|
// TestScriptSorting verifies that scripts are sorted correctly by Priority, Sequence, then Name
|
||||||
func TestScriptSorting(t *testing.T) {
|
func TestScriptSorting(t *testing.T) {
|
||||||
scripts := []*models.Script{
|
scripts := []*models.Script{
|
||||||
{Name: "script1", Priority: 2, Sequence: 1, SQL: "SELECT 1;"},
|
{Name: "z_script1", Priority: 2, Sequence: 1, SQL: "SELECT 1;"},
|
||||||
{Name: "script2", Priority: 1, Sequence: 3, SQL: "SELECT 2;"},
|
{Name: "script2", Priority: 1, Sequence: 3, SQL: "SELECT 2;"},
|
||||||
{Name: "script3", Priority: 1, Sequence: 1, SQL: "SELECT 3;"},
|
{Name: "a_script3", Priority: 1, Sequence: 1, SQL: "SELECT 3;"},
|
||||||
{Name: "script4", Priority: 1, Sequence: 2, SQL: "SELECT 4;"},
|
{Name: "b_script4", Priority: 1, Sequence: 1, SQL: "SELECT 4;"},
|
||||||
{Name: "script5", Priority: 3, Sequence: 1, SQL: "SELECT 5;"},
|
{Name: "script5", Priority: 3, Sequence: 1, SQL: "SELECT 5;"},
|
||||||
{Name: "script6", Priority: 2, Sequence: 2, SQL: "SELECT 6;"},
|
{Name: "script6", Priority: 2, Sequence: 2, SQL: "SELECT 6;"},
|
||||||
}
|
}
|
||||||
@@ -114,25 +114,35 @@ func TestScriptSorting(t *testing.T) {
|
|||||||
sortedScripts := make([]*models.Script, len(scripts))
|
sortedScripts := make([]*models.Script, len(scripts))
|
||||||
copy(sortedScripts, scripts)
|
copy(sortedScripts, scripts)
|
||||||
|
|
||||||
// Use the same sorting logic from executeScripts
|
// Sort by Priority, Sequence, then Name (matching executeScripts logic)
|
||||||
for i := 0; i < len(sortedScripts)-1; i++ {
|
for i := 0; i < len(sortedScripts)-1; i++ {
|
||||||
for j := i + 1; j < len(sortedScripts); j++ {
|
for j := i + 1; j < len(sortedScripts); j++ {
|
||||||
if sortedScripts[i].Priority > sortedScripts[j].Priority ||
|
si, sj := sortedScripts[i], sortedScripts[j]
|
||||||
(sortedScripts[i].Priority == sortedScripts[j].Priority &&
|
// Compare by priority first
|
||||||
sortedScripts[i].Sequence > sortedScripts[j].Sequence) {
|
if si.Priority > sj.Priority {
|
||||||
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
||||||
|
} else if si.Priority == sj.Priority {
|
||||||
|
// If same priority, compare by sequence
|
||||||
|
if si.Sequence > sj.Sequence {
|
||||||
|
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
||||||
|
} else if si.Sequence == sj.Sequence {
|
||||||
|
// If same sequence, compare by name
|
||||||
|
if si.Name > sj.Name {
|
||||||
|
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expected order after sorting
|
// Expected order after sorting (Priority -> Sequence -> Name)
|
||||||
expectedOrder := []string{
|
expectedOrder := []string{
|
||||||
"script3", // Priority 1, Sequence 1
|
"a_script3", // Priority 1, Sequence 1, Name a_script3
|
||||||
"script4", // Priority 1, Sequence 2
|
"b_script4", // Priority 1, Sequence 1, Name b_script4
|
||||||
"script2", // Priority 1, Sequence 3
|
"script2", // Priority 1, Sequence 3
|
||||||
"script1", // Priority 2, Sequence 1
|
"z_script1", // Priority 2, Sequence 1
|
||||||
"script6", // Priority 2, Sequence 2
|
"script6", // Priority 2, Sequence 2
|
||||||
"script5", // Priority 3, Sequence 1
|
"script5", // Priority 3, Sequence 1
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, expected := range expectedOrder {
|
for i, expected := range expectedOrder {
|
||||||
@@ -153,6 +163,13 @@ func TestScriptSorting(t *testing.T) {
|
|||||||
t.Errorf("Sequence not ascending at position %d with same priority %d: %d > %d",
|
t.Errorf("Sequence not ascending at position %d with same priority %d: %d > %d",
|
||||||
i, sortedScripts[i].Priority, sortedScripts[i].Sequence, sortedScripts[i+1].Sequence)
|
i, sortedScripts[i].Priority, sortedScripts[i].Sequence, sortedScripts[i+1].Sequence)
|
||||||
}
|
}
|
||||||
|
// Within same priority and sequence, names should be ascending
|
||||||
|
if sortedScripts[i].Priority == sortedScripts[i+1].Priority &&
|
||||||
|
sortedScripts[i].Sequence == sortedScripts[i+1].Sequence &&
|
||||||
|
sortedScripts[i].Name > sortedScripts[i+1].Name {
|
||||||
|
t.Errorf("Name not ascending at position %d with same priority/sequence: %s > %s",
|
||||||
|
i, sortedScripts[i].Name, sortedScripts[i+1].Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,10 +28,29 @@ type WriterOptions struct {
|
|||||||
// PackageName is the Go package name (for code generation)
|
// PackageName is the Go package name (for code generation)
|
||||||
PackageName string
|
PackageName string
|
||||||
|
|
||||||
|
// FlattenSchema disables schema.table dot notation and instead joins
|
||||||
|
// schema and table with an underscore (e.g., "public_users").
|
||||||
|
// Useful for databases like SQLite that do not support schemas.
|
||||||
|
FlattenSchema bool
|
||||||
|
|
||||||
// Additional options can be added here as needed
|
// Additional options can be added here as needed
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QualifiedTableName returns a schema-qualified table name.
|
||||||
|
// When flatten is true, schema and table are joined with underscore (e.g., "schema_table").
|
||||||
|
// When flatten is false, they are dot-separated (e.g., "schema.table").
|
||||||
|
// If schema is empty, just the table name is returned regardless of flatten.
|
||||||
|
func QualifiedTableName(schema, table string, flatten bool) string {
|
||||||
|
if schema == "" {
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
if flatten {
|
||||||
|
return schema + "_" + table
|
||||||
|
}
|
||||||
|
return schema + "." + table
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeFilename removes quotes, comments, and invalid characters from identifiers
|
// SanitizeFilename removes quotes, comments, and invalid characters from identifiers
|
||||||
// to make them safe for use in filenames. This handles:
|
// to make them safe for use in filenames. This handles:
|
||||||
// - Double and single quotes: "table_name" or 'table_name' -> table_name
|
// - Double and single quotes: "table_name" or 'table_name' -> table_name
|
||||||
|
|||||||
346
vendor/modules.txt
vendored
346
vendor/modules.txt
vendored
@@ -1,6 +1,92 @@
|
|||||||
|
# 4d63.com/gocheckcompilerdirectives v1.3.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# 4d63.com/gochecknoglobals v0.2.2
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/4meepo/tagalign v1.4.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/Abirdcfly/dupword v0.1.3
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/Antonboom/errname v1.0.0
|
||||||
|
## explicit; go 1.22.1
|
||||||
|
# github.com/Antonboom/nilnil v1.0.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/Antonboom/testifylint v1.5.2
|
||||||
|
## explicit; go 1.22.1
|
||||||
|
# github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/Crocmagnon/fatcontext v0.7.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# github.com/Masterminds/semver/v3 v3.3.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/OpenPeeDeeP/depguard/v2 v2.2.1
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# github.com/alecthomas/go-check-sumtype v0.3.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/alexkohler/nakedret/v2 v2.0.5
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/alexkohler/prealloc v1.0.0
|
||||||
|
## explicit; go 1.15
|
||||||
|
# github.com/alingse/asasalint v0.0.11
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/alingse/nilnesserr v0.1.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ashanbrown/forbidigo v1.6.0
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/ashanbrown/makezero v1.2.0
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/beorn7/perks v1.0.1
|
||||||
|
## explicit; go 1.11
|
||||||
|
# github.com/bkielbasa/cyclop v1.2.3
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/blizzy78/varnamelen v0.8.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/bombsimon/wsl/v4 v4.5.0
|
||||||
|
## explicit; go 1.22
|
||||||
|
# github.com/breml/bidichk v0.3.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/breml/errchkjson v0.4.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/butuzov/ireturn v0.3.1
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/butuzov/mirror v1.3.0
|
||||||
|
## explicit; go 1.19
|
||||||
|
# github.com/catenacyber/perfsprint v0.8.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ccojocar/zxcvbn-go v1.0.2
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/cespare/xxhash/v2 v2.3.0
|
||||||
|
## explicit; go 1.11
|
||||||
|
# github.com/charithe/durationcheck v0.0.10
|
||||||
|
## explicit; go 1.14
|
||||||
|
# github.com/chavacava/garif v0.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/ckaznocha/intrange v0.3.0
|
||||||
|
## explicit; go 1.22
|
||||||
|
# github.com/curioswitch/go-reassign v0.3.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/daixiang0/gci v0.13.5
|
||||||
|
## explicit; go 1.21
|
||||||
# github.com/davecgh/go-spew v1.1.1
|
# github.com/davecgh/go-spew v1.1.1
|
||||||
## explicit
|
## explicit
|
||||||
github.com/davecgh/go-spew/spew
|
github.com/davecgh/go-spew/spew
|
||||||
|
# github.com/denis-tingaikin/go-header v0.5.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/ettle/strcase v0.2.0
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/fatih/color v1.18.0
|
||||||
|
## explicit; go 1.17
|
||||||
|
# github.com/fatih/structtag v1.2.0
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/firefart/nonamedreturns v1.0.5
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/fsnotify/fsnotify v1.5.4
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/fzipp/gocyclo v0.6.0
|
||||||
|
## explicit; go 1.18
|
||||||
# github.com/gdamore/encoding v1.0.1
|
# github.com/gdamore/encoding v1.0.1
|
||||||
## explicit; go 1.9
|
## explicit; go 1.9
|
||||||
github.com/gdamore/encoding
|
github.com/gdamore/encoding
|
||||||
@@ -44,9 +130,75 @@ github.com/gdamore/tcell/v2/terminfo/x/xfce
|
|||||||
github.com/gdamore/tcell/v2/terminfo/x/xterm
|
github.com/gdamore/tcell/v2/terminfo/x/xterm
|
||||||
github.com/gdamore/tcell/v2/terminfo/x/xterm_ghostty
|
github.com/gdamore/tcell/v2/terminfo/x/xterm_ghostty
|
||||||
github.com/gdamore/tcell/v2/terminfo/x/xterm_kitty
|
github.com/gdamore/tcell/v2/terminfo/x/xterm_kitty
|
||||||
|
# github.com/ghostiam/protogetter v0.3.9
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/go-critic/go-critic v0.12.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/go-toolsmith/astcast v1.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/go-toolsmith/astcopy v1.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/go-toolsmith/astequal v1.2.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/go-toolsmith/astfmt v1.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/go-toolsmith/astp v1.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/go-toolsmith/strparse v1.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/go-toolsmith/typep v1.1.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/go-viper/mapstructure/v2 v2.2.1
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/go-xmlfmt/xmlfmt v1.1.3
|
||||||
|
## explicit
|
||||||
|
# github.com/gobwas/glob v0.2.3
|
||||||
|
## explicit
|
||||||
|
# github.com/gofrs/flock v0.12.1
|
||||||
|
## explicit; go 1.21.0
|
||||||
|
# github.com/golang/protobuf v1.5.3
|
||||||
|
## explicit; go 1.9
|
||||||
|
# github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/golangci/go-printf-func-name v0.1.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/golangci/golangci-lint v1.64.8
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# github.com/golangci/misspell v0.6.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/golangci/plugin-module-register v0.1.1
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/golangci/revgrep v0.8.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/google/go-cmp v0.7.0
|
||||||
|
## explicit; go 1.21
|
||||||
# github.com/google/uuid v1.6.0
|
# github.com/google/uuid v1.6.0
|
||||||
## explicit
|
## explicit
|
||||||
github.com/google/uuid
|
github.com/google/uuid
|
||||||
|
# github.com/gordonklaus/ineffassign v0.1.0
|
||||||
|
## explicit; go 1.14
|
||||||
|
# github.com/gostaticanalysis/analysisutil v0.7.1
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/gostaticanalysis/comment v1.5.0
|
||||||
|
## explicit; go 1.22.9
|
||||||
|
# github.com/gostaticanalysis/forcetypeassert v0.2.0
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# github.com/gostaticanalysis/nilerr v0.1.1
|
||||||
|
## explicit; go 1.15
|
||||||
|
# github.com/hashicorp/go-immutable-radix/v2 v2.1.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/hashicorp/go-version v1.7.0
|
||||||
|
## explicit
|
||||||
|
# github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/hashicorp/hcl v1.0.0
|
||||||
|
## explicit
|
||||||
|
# github.com/hexops/gotextdiff v1.0.3
|
||||||
|
## explicit; go 1.16
|
||||||
# github.com/inconshreveable/mousetrap v1.1.0
|
# github.com/inconshreveable/mousetrap v1.1.0
|
||||||
## explicit; go 1.18
|
## explicit; go 1.18
|
||||||
github.com/inconshreveable/mousetrap
|
github.com/inconshreveable/mousetrap
|
||||||
@@ -68,23 +220,115 @@ github.com/jackc/pgx/v5/pgconn/ctxwatch
|
|||||||
github.com/jackc/pgx/v5/pgconn/internal/bgreader
|
github.com/jackc/pgx/v5/pgconn/internal/bgreader
|
||||||
github.com/jackc/pgx/v5/pgproto3
|
github.com/jackc/pgx/v5/pgproto3
|
||||||
github.com/jackc/pgx/v5/pgtype
|
github.com/jackc/pgx/v5/pgtype
|
||||||
|
# github.com/jgautheron/goconst v1.7.1
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/jingyugao/rowserrcheck v1.1.1
|
||||||
|
## explicit; go 1.13
|
||||||
# github.com/jinzhu/inflection v1.0.0
|
# github.com/jinzhu/inflection v1.0.0
|
||||||
## explicit
|
## explicit
|
||||||
github.com/jinzhu/inflection
|
github.com/jinzhu/inflection
|
||||||
|
# github.com/jjti/go-spancheck v0.6.4
|
||||||
|
## explicit; go 1.22.1
|
||||||
|
# github.com/julz/importas v0.2.0
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/karamaru-alpha/copyloopvar v1.2.1
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/kisielk/errcheck v1.9.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/kkHAIKE/contextcheck v1.1.6
|
||||||
|
## explicit; go 1.23.0
|
||||||
# github.com/kr/pretty v0.3.1
|
# github.com/kr/pretty v0.3.1
|
||||||
## explicit; go 1.12
|
## explicit; go 1.12
|
||||||
|
# github.com/kulti/thelper v0.6.3
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/kunwardeep/paralleltest v1.0.10
|
||||||
|
## explicit; go 1.17
|
||||||
|
# github.com/lasiar/canonicalheader v1.1.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ldez/exptostd v0.4.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ldez/gomoddirectives v0.6.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ldez/grignotin v0.9.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ldez/tagliatelle v0.7.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ldez/usetesting v0.4.2
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/leonklingele/grouper v1.1.2
|
||||||
|
## explicit; go 1.18
|
||||||
# github.com/lucasb-eyer/go-colorful v1.2.0
|
# github.com/lucasb-eyer/go-colorful v1.2.0
|
||||||
## explicit; go 1.12
|
## explicit; go 1.12
|
||||||
github.com/lucasb-eyer/go-colorful
|
github.com/lucasb-eyer/go-colorful
|
||||||
|
# github.com/macabu/inamedparam v0.1.3
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/magiconair/properties v1.8.6
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/maratori/testableexamples v1.0.0
|
||||||
|
## explicit; go 1.19
|
||||||
|
# github.com/maratori/testpackage v1.1.1
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/matoous/godox v1.1.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/mattn/go-colorable v0.1.14
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/mattn/go-isatty v0.0.20
|
||||||
|
## explicit; go 1.15
|
||||||
# github.com/mattn/go-runewidth v0.0.16
|
# github.com/mattn/go-runewidth v0.0.16
|
||||||
## explicit; go 1.9
|
## explicit; go 1.9
|
||||||
github.com/mattn/go-runewidth
|
github.com/mattn/go-runewidth
|
||||||
|
# github.com/matttproud/golang_protobuf_extensions v1.0.1
|
||||||
|
## explicit
|
||||||
|
# github.com/mgechev/revive v1.7.0
|
||||||
|
## explicit; go 1.22.1
|
||||||
|
# github.com/mitchellh/go-homedir v1.1.0
|
||||||
|
## explicit
|
||||||
|
# github.com/mitchellh/mapstructure v1.5.0
|
||||||
|
## explicit; go 1.14
|
||||||
|
# github.com/moricho/tparallel v0.3.2
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/nakabonne/nestif v0.3.1
|
||||||
|
## explicit; go 1.15
|
||||||
|
# github.com/nishanths/exhaustive v0.12.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/nishanths/predeclared v0.2.2
|
||||||
|
## explicit; go 1.14
|
||||||
|
# github.com/nunnatsa/ginkgolinter v0.19.1
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# github.com/olekukonko/tablewriter v0.0.5
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/pelletier/go-toml v1.9.5
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/pelletier/go-toml/v2 v2.2.3
|
||||||
|
## explicit; go 1.21.0
|
||||||
# github.com/pmezard/go-difflib v1.0.0
|
# github.com/pmezard/go-difflib v1.0.0
|
||||||
## explicit
|
## explicit
|
||||||
github.com/pmezard/go-difflib/difflib
|
github.com/pmezard/go-difflib/difflib
|
||||||
|
# github.com/polyfloyd/go-errorlint v1.7.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/prometheus/client_golang v1.12.1
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/prometheus/client_model v0.2.0
|
||||||
|
## explicit; go 1.9
|
||||||
|
# github.com/prometheus/common v0.32.1
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/prometheus/procfs v0.7.3
|
||||||
|
## explicit; go 1.13
|
||||||
# github.com/puzpuzpuz/xsync/v3 v3.5.1
|
# github.com/puzpuzpuz/xsync/v3 v3.5.1
|
||||||
## explicit; go 1.18
|
## explicit; go 1.18
|
||||||
github.com/puzpuzpuz/xsync/v3
|
github.com/puzpuzpuz/xsync/v3
|
||||||
|
# github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1
|
||||||
|
## explicit; go 1.19
|
||||||
|
# github.com/quasilyte/go-ruleguard/dsl v0.3.22
|
||||||
|
## explicit; go 1.15
|
||||||
|
# github.com/quasilyte/gogrep v0.5.0
|
||||||
|
## explicit; go 1.16
|
||||||
|
# github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727
|
||||||
|
## explicit; go 1.14
|
||||||
|
# github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567
|
||||||
|
## explicit; go 1.17
|
||||||
|
# github.com/raeperd/recvcheck v0.2.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
# github.com/rivo/tview v0.42.0
|
# github.com/rivo/tview v0.42.0
|
||||||
## explicit; go 1.18
|
## explicit; go 1.18
|
||||||
github.com/rivo/tview
|
github.com/rivo/tview
|
||||||
@@ -93,20 +337,76 @@ github.com/rivo/tview
|
|||||||
github.com/rivo/uniseg
|
github.com/rivo/uniseg
|
||||||
# github.com/rogpeppe/go-internal v1.14.1
|
# github.com/rogpeppe/go-internal v1.14.1
|
||||||
## explicit; go 1.23
|
## explicit; go 1.23
|
||||||
|
# github.com/ryancurrah/gomodguard v1.3.5
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ryanrolds/sqlclosecheck v0.5.1
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/sanposhiho/wastedassign/v2 v2.1.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/santhosh-tekuri/jsonschema/v6 v6.0.1
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/sashamelentyev/interfacebloat v1.1.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/sashamelentyev/usestdlibvars v1.28.0
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/securego/gosec/v2 v2.22.2
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# github.com/sirupsen/logrus v1.9.3
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/sivchari/containedctx v1.0.3
|
||||||
|
## explicit; go 1.17
|
||||||
|
# github.com/sivchari/tenv v1.12.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/sonatard/noctx v0.1.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/sourcegraph/go-diff v0.7.0
|
||||||
|
## explicit; go 1.14
|
||||||
|
# github.com/spf13/afero v1.12.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/spf13/cast v1.5.0
|
||||||
|
## explicit; go 1.18
|
||||||
# github.com/spf13/cobra v1.10.2
|
# github.com/spf13/cobra v1.10.2
|
||||||
## explicit; go 1.15
|
## explicit; go 1.15
|
||||||
github.com/spf13/cobra
|
github.com/spf13/cobra
|
||||||
|
# github.com/spf13/jwalterweatherman v1.1.0
|
||||||
|
## explicit
|
||||||
# github.com/spf13/pflag v1.0.10
|
# github.com/spf13/pflag v1.0.10
|
||||||
## explicit; go 1.12
|
## explicit; go 1.12
|
||||||
github.com/spf13/pflag
|
github.com/spf13/pflag
|
||||||
|
# github.com/spf13/viper v1.12.0
|
||||||
|
## explicit; go 1.17
|
||||||
|
# github.com/ssgreg/nlreturn/v2 v2.2.1
|
||||||
|
## explicit; go 1.13
|
||||||
|
# github.com/stbenjam/no-sprintf-host-port v0.2.0
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/stretchr/objx v0.5.2
|
||||||
|
## explicit; go 1.20
|
||||||
# github.com/stretchr/testify v1.11.1
|
# github.com/stretchr/testify v1.11.1
|
||||||
## explicit; go 1.17
|
## explicit; go 1.17
|
||||||
github.com/stretchr/testify/assert
|
github.com/stretchr/testify/assert
|
||||||
github.com/stretchr/testify/assert/yaml
|
github.com/stretchr/testify/assert/yaml
|
||||||
github.com/stretchr/testify/require
|
github.com/stretchr/testify/require
|
||||||
|
# github.com/subosito/gotenv v1.4.1
|
||||||
|
## explicit; go 1.18
|
||||||
|
# github.com/tdakkota/asciicheck v0.4.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/tetafro/godot v1.5.0
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/timonwong/loggercheck v0.10.1
|
||||||
|
## explicit; go 1.22.0
|
||||||
# github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
|
# github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
|
||||||
## explicit
|
## explicit
|
||||||
github.com/tmthrgd/go-hex
|
github.com/tmthrgd/go-hex
|
||||||
|
# github.com/tomarrell/wrapcheck/v2 v2.10.0
|
||||||
|
## explicit; go 1.21
|
||||||
|
# github.com/tommy-muehle/go-mnd/v2 v2.5.1
|
||||||
|
## explicit; go 1.12
|
||||||
|
# github.com/ultraware/funlen v0.2.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# github.com/ultraware/whitespace v0.2.0
|
||||||
|
## explicit; go 1.20
|
||||||
# github.com/uptrace/bun v1.2.16
|
# github.com/uptrace/bun v1.2.16
|
||||||
## explicit; go 1.24.0
|
## explicit; go 1.24.0
|
||||||
github.com/uptrace/bun
|
github.com/uptrace/bun
|
||||||
@@ -118,6 +418,10 @@ github.com/uptrace/bun/internal
|
|||||||
github.com/uptrace/bun/internal/parser
|
github.com/uptrace/bun/internal/parser
|
||||||
github.com/uptrace/bun/internal/tagparser
|
github.com/uptrace/bun/internal/tagparser
|
||||||
github.com/uptrace/bun/schema
|
github.com/uptrace/bun/schema
|
||||||
|
# github.com/uudashr/gocognit v1.2.0
|
||||||
|
## explicit; go 1.19
|
||||||
|
# github.com/uudashr/iface v1.3.1
|
||||||
|
## explicit; go 1.22.1
|
||||||
# github.com/vmihailenco/msgpack/v5 v5.4.1
|
# github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||||
## explicit; go 1.19
|
## explicit; go 1.19
|
||||||
github.com/vmihailenco/msgpack/v5
|
github.com/vmihailenco/msgpack/v5
|
||||||
@@ -127,9 +431,37 @@ github.com/vmihailenco/msgpack/v5/msgpcode
|
|||||||
github.com/vmihailenco/tagparser/v2
|
github.com/vmihailenco/tagparser/v2
|
||||||
github.com/vmihailenco/tagparser/v2/internal
|
github.com/vmihailenco/tagparser/v2/internal
|
||||||
github.com/vmihailenco/tagparser/v2/internal/parser
|
github.com/vmihailenco/tagparser/v2/internal/parser
|
||||||
|
# github.com/xen0n/gosmopolitan v1.2.2
|
||||||
|
## explicit; go 1.19
|
||||||
|
# github.com/yagipy/maintidx v1.0.0
|
||||||
|
## explicit; go 1.17
|
||||||
|
# github.com/yeya24/promlinter v0.3.0
|
||||||
|
## explicit; go 1.20
|
||||||
|
# github.com/ykadowak/zerologlint v0.1.5
|
||||||
|
## explicit; go 1.19
|
||||||
|
# gitlab.com/bosi/decorder v0.4.2
|
||||||
|
## explicit; go 1.20
|
||||||
|
# go-simpler.org/musttag v0.13.0
|
||||||
|
## explicit; go 1.20
|
||||||
|
# go-simpler.org/sloglint v0.9.0
|
||||||
|
## explicit; go 1.22.0
|
||||||
|
# go.uber.org/atomic v1.7.0
|
||||||
|
## explicit; go 1.13
|
||||||
|
# go.uber.org/automaxprocs v1.6.0
|
||||||
|
## explicit; go 1.20
|
||||||
|
# go.uber.org/multierr v1.6.0
|
||||||
|
## explicit; go 1.12
|
||||||
|
# go.uber.org/zap v1.24.0
|
||||||
|
## explicit; go 1.19
|
||||||
# golang.org/x/crypto v0.41.0
|
# golang.org/x/crypto v0.41.0
|
||||||
## explicit; go 1.23.0
|
## explicit; go 1.23.0
|
||||||
golang.org/x/crypto/pbkdf2
|
golang.org/x/crypto/pbkdf2
|
||||||
|
# golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac
|
||||||
|
## explicit; go 1.18
|
||||||
|
# golang.org/x/mod v0.26.0
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# golang.org/x/sync v0.16.0
|
||||||
|
## explicit; go 1.23.0
|
||||||
# golang.org/x/sys v0.38.0
|
# golang.org/x/sys v0.38.0
|
||||||
## explicit; go 1.24.0
|
## explicit; go 1.24.0
|
||||||
golang.org/x/sys/cpu
|
golang.org/x/sys/cpu
|
||||||
@@ -156,6 +488,20 @@ golang.org/x/text/transform
|
|||||||
golang.org/x/text/unicode/bidi
|
golang.org/x/text/unicode/bidi
|
||||||
golang.org/x/text/unicode/norm
|
golang.org/x/text/unicode/norm
|
||||||
golang.org/x/text/width
|
golang.org/x/text/width
|
||||||
|
# golang.org/x/tools v0.35.0
|
||||||
|
## explicit; go 1.23.0
|
||||||
|
# google.golang.org/protobuf v1.36.5
|
||||||
|
## explicit; go 1.21
|
||||||
|
# gopkg.in/ini.v1 v1.67.0
|
||||||
|
## explicit
|
||||||
|
# gopkg.in/yaml.v2 v2.4.0
|
||||||
|
## explicit; go 1.15
|
||||||
# gopkg.in/yaml.v3 v3.0.1
|
# gopkg.in/yaml.v3 v3.0.1
|
||||||
## explicit
|
## explicit
|
||||||
gopkg.in/yaml.v3
|
gopkg.in/yaml.v3
|
||||||
|
# honnef.co/go/tools v0.6.1
|
||||||
|
## explicit; go 1.23
|
||||||
|
# mvdan.cc/gofumpt v0.7.0
|
||||||
|
## explicit; go 1.22
|
||||||
|
# mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f
|
||||||
|
## explicit; go 1.21
|
||||||
|
|||||||
Reference in New Issue
Block a user