22 Commits

Author SHA1 Message Date
Hein
5fb09b78c3 feat(relations): 🎉 add flatten schema option for output
All checks were successful
* Introduce `--flatten-schema` flag to convert, merge, and split commands.
* Modify database writing functions to support flattened schema names.
* Update template functions to handle schema.table naming convention.
* Enhance PostgreSQL writer to utilize flattened schema in generated SQL.
* Update tests to ensure compatibility with new flattening feature.
* Update dependencies.
2026-02-05 14:07:55 +02:00
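The flatten option maps schema-qualified names onto a single namespace. A minimal sketch of the transform, assuming a helper of this shape (the name flattenTableName is illustrative; the real writers read a FlattenSchema field on WriterOptions, as the diffs below show):

package sketch

// flattenTableName collapses "auth"."users" into "auth_users" when
// flattening is enabled, so schema-less targets keep names unique.
func flattenTableName(schema, table string, flatten bool) string {
    if !flatten || schema == "" {
        return table
    }
    return schema + "_" + table
}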
5d9770b430 test(pgsql, reflectutil): add comprehensive test coverage
All checks were successful
* Introduce tests for PostgreSQL data types and keywords.
* Implement tests for reflect utility functions.
* Ensure consistency and correctness of type conversions and keyword mappings.
* Validate behavior for various edge cases and input types.
2026-01-31 22:30:00 +02:00
f2d500f98d feat(merge): 🎉 Add support for constraints and indexes in merge results
All checks were successful
* Enhance MergeResult to track added constraints and indexes.
* Update merge logic to increment counters for added constraints and indexes.
* Modify GetMergeSummary to include constraints and indexes in the output.
* Add comprehensive tests for merging constraints and indexes.
2026-01-31 21:30:55 +02:00
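A guess at the counters this adds to MergeResult; the field and method names here are assumptions, not the repository's API:

package sketch

import "fmt"

// MergeResult tracks what a merge changed; this sketch shows only the
// two counters the commit above introduces.
type MergeResult struct {
    ConstraintsAdded int
    IndexesAdded     int
}

// Summary mirrors what GetMergeSummary might now include.
func (r *MergeResult) Summary() string {
    return fmt.Sprintf("constraints added: %d, indexes added: %d",
        r.ConstraintsAdded, r.IndexesAdded)
}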
2ec9991324 feat(merge): 🎉 Add support for merging constraints and indexes
Some checks failed
CI / Test (1.24) (push) failed
* Implement mergeConstraints to handle table constraints
* Implement mergeIndexes to handle table indexes
* Update mergeTables to include constraints and indexes during merge
2026-01-31 21:27:28 +02:00
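A sketch of what mergeConstraints could look like, assuming constraints are keyed by name as in the models.Constraint maps used by the diff tests further down; merging by name only is an assumed policy:

package sketch

import "git.warky.dev/wdevs/relspecgo/pkg/models"

// mergeConstraints copies constraints the target lacks and counts additions.
func mergeConstraints(target, source map[string]*models.Constraint, added *int) {
    for name, c := range source {
        if _, exists := target[name]; !exists {
            target[name] = c
            *added++
        }
    }
}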
a3e45c206d feat(writer): 🎉 Enhance SQL execution logging and add statement type detection
All checks were successful
* Log statement type during execution for better debugging
* Introduce detectStatementType function to categorize SQL statements
* Update unique constraint naming convention in tests
2026-01-31 21:19:48 +02:00
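detectStatementType presumably classifies a statement by its leading keyword; a plausible sketch (the category labels are assumptions):

package sketch

import "strings"

// detectStatementType returns a coarse category for execution logging.
func detectStatementType(sql string) string {
    s := strings.ToUpper(strings.TrimSpace(sql))
    switch {
    case strings.HasPrefix(s, "CREATE TABLE"):
        return "CREATE TABLE"
    case strings.HasPrefix(s, "ALTER TABLE"):
        return "ALTER TABLE"
    case strings.HasPrefix(s, "CREATE INDEX"), strings.HasPrefix(s, "CREATE UNIQUE INDEX"):
        return "CREATE INDEX"
    case strings.HasPrefix(s, "DO"):
        return "DO BLOCK"
    default:
        return "OTHER"
    }
}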
165623bb1d feat(pgsql): Add templates for constraints and sequences
All checks were successful
* Introduce new templates for creating unique, check, and foreign key constraints with existence checks.
* Add templates for setting sequence values and creating sequences.
* Refactor existing SQL generation logic to utilize new templates for better maintainability and readability.
* Ensure identifiers are properly quoted to handle special characters and reserved keywords.
2026-01-31 21:04:43 +02:00
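The "existence checks" suggest DO-block guards around ALTER TABLE. A sketch of what such a template might contain; the placeholder names are assumptions, not the repository's template text:

package sketch

// uniqueConstraintTmpl guards the ALTER TABLE behind a pg_constraint
// lookup, so the generated SQL can be re-run safely.
const uniqueConstraintTmpl = `
DO $$
BEGIN
    IF NOT EXISTS (
        SELECT 1 FROM pg_constraint WHERE conname = '{{.ConstraintName}}'
    ) THEN
        ALTER TABLE "{{.Schema}}"."{{.Table}}"
            ADD CONSTRAINT "{{.ConstraintName}}" UNIQUE ({{.Columns}});
    END IF;
END $$;`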
3c20c3c5d9 feat(writer): 🎉 Add support for check constraints in schema generation
All checks were successful
* Implement check constraints in the schema writer.
* Generate SQL statements to add check constraints if they do not exist.
* Add tests to verify correct generation of check constraints.
2026-01-31 20:42:19 +02:00
a54594e49b feat(writer): 🎉 Add support for unique constraints in schema generation
All checks were successful
* Implement unique constraint handling in GenerateSchemaStatements
* Add writeUniqueConstraints method for generating SQL statements
* Create unit test for unique constraints in writer_test.go
2026-01-31 20:33:08 +02:00
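A rough shape for the unique-constraint pass, using the models.Table and models.Constraint fields that appear in the diff tests below; passing the constraint map separately is an assumption about the model layout, and %q quoting matches SQL double-quoted identifiers only for simple names:

package sketch

import (
    "fmt"
    "strings"

    "git.warky.dev/wdevs/relspecgo/pkg/models"
)

// writeUniqueConstraints emits one ALTER TABLE per UNIQUE constraint.
func writeUniqueConstraints(table *models.Table, constraints map[string]*models.Constraint) []string {
    var out []string
    for _, c := range constraints {
        if c.Type != "UNIQUE" {
            continue
        }
        out = append(out, fmt.Sprintf(`ALTER TABLE %q.%q ADD CONSTRAINT %q UNIQUE (%s);`,
            table.Schema, table.Name, c.Name, strings.Join(c.Columns, ", ")))
    }
    return out
}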
cafe6a461f feat(scripts): 🎉 Add --ignore-errors flag for script execution
All checks were successful
- Allow continued execution of scripts even if errors occur.
- Update execution summary to include counts of successful and failed scripts.
- Enhance error handling and reporting for better visibility.
2026-01-31 20:21:22 +02:00
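Behind the flag, execution plausibly becomes a counting loop instead of fail-fast. A sketch using database/sql; the real writer reports its counts through Options().Metadata, as the scripts diff below shows:

package sketch

import "database/sql"

// executeAll runs every statement, stopping on the first error unless
// ignoreErrors is set; it returns counts either way.
func executeAll(db *sql.DB, stmts []string, ignoreErrors bool) (success, failed int, err error) {
    for _, s := range stmts {
        if _, execErr := db.Exec(s); execErr != nil {
            failed++
            if !ignoreErrors {
                return success, failed, execErr
            }
            continue
        }
        success++
    }
    return success, failed, nil
}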
abdb9b4c78 feat(dbml/reader): 🎉 Implement splitIdentifier function for parsing
All checks were successful
2026-01-31 19:45:24 +02:00
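The commit carries no body, but a splitIdentifier for DBML most likely splits dotted names while respecting double quotes; a guess at its behavior:

package sketch

import "strings"

// splitIdentifier splits "schema"."table" into its parts, treating dots
// inside double quotes as literal characters.
func splitIdentifier(ident string) []string {
    var parts []string
    var cur strings.Builder
    inQuotes := false
    for _, r := range ident {
        switch {
        case r == '"':
            inQuotes = !inQuotes
        case r == '.' && !inQuotes:
            parts = append(parts, cur.String())
            cur.Reset()
        default:
            cur.WriteRune(r)
        }
    }
    return append(parts, cur.String())
}

Under this sketch, splitIdentifier(`"public"."users"`) returns the two parts "public" and "users", while splitIdentifier("users") returns just "users".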
e7a15c8e4f feat(writer): 🎉 Implement add column statements for schema evolution
All checks were successful
* Add functionality to generate ALTER TABLE ADD COLUMN statements for existing tables.
* Introduce tests for generating and writing add column statements.
* Enhance schema evolution capabilities when new columns are added.
2026-01-31 19:12:00 +02:00
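One plausible rendering of the additive statement; PostgreSQL's ADD COLUMN IF NOT EXISTS keeps it idempotent (the helper name is illustrative):

package sketch

import "fmt"

// addColumnStatement renders an idempotent ALTER TABLE for a new column.
func addColumnStatement(schema, table, column, sqlType string) string {
    return fmt.Sprintf(`ALTER TABLE %q.%q ADD COLUMN IF NOT EXISTS %q %s;`,
        schema, table, column, sqlType)
}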
c36b5ede2b feat(writer): 🎉 Enhance primary key handling and add tests
All checks were successful
* Implement checks for existing primary keys before adding new ones.
* Drop auto-generated primary keys if they exist.
* Add tests for primary key existence and column size specifiers.
* Improve type conversion handling for PostgreSQL compatibility.
2026-01-31 18:59:32 +02:00
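Checking for an existing primary key likely means probing pg_constraint. A sketch of such a probe (the catalog joins are standard PostgreSQL; the exact query used by the writer is not shown here):

package sketch

// pkExistsQuery returns the name of a table's primary key constraint, if any.
const pkExistsQuery = `
SELECT c.conname
FROM pg_constraint c
JOIN pg_class t ON t.oid = c.conrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
WHERE n.nspname = $1 AND t.relname = $2 AND c.contype = 'p';`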
51ab29f8e3 feat(writer): 🎉 Update index naming conventions for consistency
All checks were successful
* Use SQLName() for primary key constraint naming
* Enhance index name formatting with column suffix
2026-01-31 17:23:18 +02:00
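The "column suffix" presumably produces names like idx_users_email; a one-liner sketch (the exact format string is an assumption):

package sketch

import "strings"

// indexName builds an index name from the table and its column list.
func indexName(table string, columns []string) string {
    return "idx_" + table + "_" + strings.Join(columns, "_")
}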
f532fc110c feat(writer): 🎉 Enhance script execution order and add symlink skipping
All checks were successful
* Update script execution to sort by Priority, Sequence, and Name.
* Add functionality to skip symbolic links during directory scanning.
* Improve documentation to reflect changes in execution order and features.
* Add tests for symlink skipping and ensure correct script sorting.
2026-01-31 16:59:17 +02:00
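The three-level sort appears verbatim in the scripts diff below; the symlink check during the recursive scan might look like this (the predicate name is illustrative):

package sketch

import "io/fs"

// isSymlink reports whether a directory entry is a symbolic link, so the
// scanner can skip it rather than follow it.
func isSymlink(d fs.DirEntry) bool {
    return d.Type()&fs.ModeSymlink != 0
}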
92dff99725 feat(writer): enhance type conversion for PostgreSQL compatibility and add tests
Some checks failed
Integration Tests / Integration Tests (push) failed
2026-01-29 21:36:23 +02:00
283b568adb feat(pgsql): add execution reporting for SQL statements
All checks were successful
- Implemented ExecutionReport to track the execution status of SQL statements.
- Added SchemaReport and TableReport to monitor execution per schema and table.
- Enhanced WriteDatabase to execute SQL directly on a PostgreSQL database if a connection string is provided.
- Included error handling and logging for failed statements during execution.
- Added functionality to write execution reports to a JSON file.
- Introduced utility functions to extract table names from CREATE TABLE statements and truncate long SQL statements for error messages.
2026-01-29 21:16:14 +02:00
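A guess at the report shapes and the JSON write; struct fields, tags, and nesting are assumptions:

package sketch

import (
    "encoding/json"
    "os"
)

// TableReport and SchemaReport roll execution results up per table and schema.
type TableReport struct {
    Table     string `json:"table"`
    Succeeded int    `json:"succeeded"`
    Failed    int    `json:"failed"`
}

type SchemaReport struct {
    Schema string        `json:"schema"`
    Tables []TableReport `json:"tables"`
}

type ExecutionReport struct {
    Schemas []SchemaReport `json:"schemas"`
}

// writeReport serializes the report to an indented JSON file.
func writeReport(path string, r *ExecutionReport) error {
    data, err := json.MarshalIndent(r, "", "  ")
    if err != nil {
        return err
    }
    return os.WriteFile(path, data, 0o644)
}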
122743ee43 feat(writer): 🎉 Improve primary key handling by checking for explicit constraints and columns
Some checks failed
CI / Build (push) failed
2026-01-28 22:08:27 +02:00
91b6046b9b feat(writer): 🎉 Enhance PostgreSQL writer; fix bugs found using origin
Some checks failed
CI / Test (1.24) (push) and CI / Build (push) failed
2026-01-28 21:59:25 +02:00
6f55505444 feat(writer): 🎉 Enhance model name generation and formatting
All checks were successful
* Update model name generation to include schema name.
* Add gofmt execution after writing output files.
* Refactor relationship field naming to include schema.
* Update tests to reflect changes in model names and relationships.
2026-01-10 18:28:41 +02:00
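Schema-qualified naming plausibly concatenates PascalCase parts, e.g. auth.users becoming AuthUsers. A sketch; the casing rules are assumptions, and this handles ASCII snake_case identifiers only:

package sketch

import "strings"

// modelName prefixes the table's model name with its schema.
func modelName(schema, table string) string {
    return pascal(schema) + pascal(table)
}

// pascal converts snake_case to PascalCase.
func pascal(s string) string {
    parts := strings.Split(s, "_")
    for i, p := range parts {
        if p != "" {
            parts[i] = strings.ToUpper(p[:1]) + p[1:]
        }
    }
    return strings.Join(parts, "")
}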
e0e7b64c69 feat(writer): 🎉 Resolve field name collisions with methods
All checks were successful
* Implement field name collision resolution in model generation.
* Add tests to verify renaming of fields that conflict with generated method names.
* Ensure primary key type safety in UpdateID method.
2026-01-10 17:54:33 +02:00
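The collision check presumably compares candidate field names against generated method names such as UpdateID; a sketch (the suffix policy is an assumption):

package sketch

// resolveFieldName renames a field that would clash with a generated method.
func resolveFieldName(name string, methodNames map[string]bool) string {
    if methodNames[name] {
        return name + "Field" // illustrative suffix choice
    }
    return name
}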
4181cb1fbd feat(writer): 🎉 Enhance relationship field naming and uniqueness
All checks were successful
* Update relationship field naming conventions for has-one and has-many relationships.
* Implement logic to ensure unique field names by tracking used names.
* Add tests to verify new naming conventions and uniqueness constraints.
2026-01-10 17:45:13 +02:00
120ffc6a5a feat(writer): 🎉 Update relationship field naming convention
All checks were successful
* Refactor generateRelationshipFieldName to use foreign key columns for unique naming.
* Add test for multiple references to the same table to ensure unique relationship field names.
* Update existing tests to reflect new naming convention.
2026-01-10 13:49:54 +02:00
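Deriving the field name from the foreign key column means author_id and editor_id referencing the same users table yield distinct fields. A sketch; the _id trimming and the fallback are assumptions:

package sketch

import "strings"

// relationshipFieldName names a relation after its FK column, falling back
// to the referenced table when the column has no _id suffix.
func relationshipFieldName(fkColumn, refTable string) string {
    base := strings.TrimSuffix(fkColumn, "_id")
    if base == fkColumn {
        base = refTable
    }
    return base
}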
68 changed files with 9283 additions and 363 deletions

.gitignore

@@ -47,3 +47,4 @@ dist/
 build/
 bin/
 tests/integration/failed_statements_example.txt
+test_output.log


@@ -38,13 +38,14 @@ import (
 )
 
 var (
-    convertSourceType   string
-    convertSourcePath   string
-    convertSourceConn   string
-    convertTargetType   string
-    convertTargetPath   string
-    convertPackageName  string
-    convertSchemaFilter string
+    convertSourceType    string
+    convertSourcePath    string
+    convertSourceConn    string
+    convertTargetType    string
+    convertTargetPath    string
+    convertPackageName   string
+    convertSchemaFilter  string
+    convertFlattenSchema bool
 )
 
 var convertCmd = &cobra.Command{
@@ -148,6 +149,7 @@ func init() {
     convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
     convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
     convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
+    convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
 
     err := convertCmd.MarkFlagRequired("from")
     if err != nil {
@@ -202,7 +204,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
         fmt.Fprintf(os.Stderr, "  Schema: %s\n", convertSchemaFilter)
     }
 
-    if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter); err != nil {
+    if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema); err != nil {
         return fmt.Errorf("failed to write target: %w", err)
     }
@@ -301,12 +303,13 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Database, error) {
     return db, nil
 }
 
-func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string) error {
+func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool) error {
     var writer writers.Writer
 
     writerOpts := &writers.WriterOptions{
-        OutputPath:  outputPath,
-        PackageName: packageName,
+        OutputPath:    outputPath,
+        PackageName:   packageName,
+        FlattenSchema: flattenSchema,
     }
 
     switch strings.ToLower(dbType) {


@@ -55,6 +55,8 @@ var (
     mergeSkipSequences bool
     mergeSkipTables    string // Comma-separated table names to skip
     mergeVerbose       bool
+    mergeReportPath    string // Path to write merge report
+    mergeFlattenSchema bool
 )
 
 var mergeCmd = &cobra.Command{
@@ -78,6 +80,12 @@ Examples:
     --source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
     --output json --output-path combined.json
 
+  # Merge and execute on PostgreSQL database with report
+  relspec merge --target json --target-path base.json \
+    --source json --source-path additional.json \
+    --output pgsql --output-conn "postgres://user:pass@localhost/target_db" \
+    --merge-report merge-report.json
+
   # Merge DBML and YAML, skip relations
   relspec merge --target dbml --target-path schema.dbml \
     --source yaml --source-path tables.yaml \
@@ -115,6 +123,8 @@ func init() {
     mergeCmd.Flags().BoolVar(&mergeSkipSequences, "skip-sequences", false, "Skip sequences during merge")
     mergeCmd.Flags().StringVar(&mergeSkipTables, "skip-tables", "", "Comma-separated list of table names to skip during merge")
     mergeCmd.Flags().BoolVar(&mergeVerbose, "verbose", false, "Show verbose output")
+    mergeCmd.Flags().StringVar(&mergeReportPath, "merge-report", "", "Path to write merge report (JSON format)")
+    mergeCmd.Flags().BoolVar(&mergeFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
 }
 
 func runMerge(cmd *cobra.Command, args []string) error {
@@ -229,7 +239,7 @@ func runMerge(cmd *cobra.Command, args []string) error {
         fmt.Fprintf(os.Stderr, "  Path: %s\n", mergeOutputPath)
     }
 
-    err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, "", targetDB, "Output")
+    err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, mergeOutputConn, targetDB, "Output", mergeFlattenSchema)
     if err != nil {
         return fmt.Errorf("failed to write output: %w", err)
     }
@@ -316,7 +326,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.Database, error) {
     return db, nil
 }
 
-func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string) error {
+func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string, flattenSchema bool) error {
     var writer writers.Writer
 
     switch strings.ToLower(dbType) {
@@ -324,59 +334,69 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string) error {
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for DBML format", label)
         }
-        writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "dctx":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for DCTX format", label)
         }
-        writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "drawdb":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for DrawDB format", label)
         }
-        writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "graphql":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for GraphQL format", label)
         }
-        writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "json":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for JSON format", label)
         }
-        writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "yaml":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for YAML format", label)
         }
-        writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "gorm":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for GORM format", label)
         }
-        writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "bun":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for Bun format", label)
         }
-        writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "drizzle":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for Drizzle format", label)
         }
-        writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "prisma":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for Prisma format", label)
         }
-        writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "typeorm":
         if filePath == "" {
             return fmt.Errorf("%s: file path is required for TypeORM format", label)
         }
-        writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
     case "pgsql":
-        writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
+        writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}
+        if connString != "" {
+            writerOpts.Metadata = map[string]interface{}{
+                "connection_string": connString,
+            }
+            // Add report path if merge report is enabled
+            if mergeReportPath != "" {
+                writerOpts.Metadata["report_path"] = mergeReportPath
+            }
+        }
+        writer = wpgsql.NewWriter(writerOpts)
     default:
         return fmt.Errorf("%s: unsupported format '%s'", label, dbType)
     }


@@ -14,10 +14,11 @@ import (
 )
 
 var (
-    scriptsDir        string
-    scriptsConn       string
-    scriptsSchemaName string
-    scriptsDBName     string
+    scriptsDir          string
+    scriptsConn         string
+    scriptsSchemaName   string
+    scriptsDBName       string
+    scriptsIgnoreErrors bool
 )
 
 var scriptsCmd = &cobra.Command{
@@ -39,8 +40,8 @@ Example filenames (hyphen format):
   1-002-create-posts.sql   # Priority 1, Sequence 2
   10-10-create-newid.pgsql # Priority 10, Sequence 10
 
-Both formats can be mixed in the same directory.
-Scripts are executed in order: Priority (ascending), then Sequence (ascending).`,
+Both formats can be mixed in the same directory and subdirectories.
+Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).`,
 }
 
 var scriptsListCmd = &cobra.Command{
@@ -48,8 +49,8 @@ var scriptsListCmd = &cobra.Command{
     Short: "List SQL scripts from a directory",
     Long: `List SQL scripts from a directory and show their execution order.
 
-The scripts are read from the specified directory and displayed in the order
-they would be executed (Priority ascending, then Sequence ascending).
+The scripts are read recursively from the specified directory and displayed in the order
+they would be executed: Priority (ascending), then Sequence (ascending), then Name (alphabetical).
 
 Example:
   relspec scripts list --dir ./migrations`,
@@ -61,10 +62,10 @@ var scriptsExecuteCmd = &cobra.Command{
     Short: "Execute SQL scripts against a database",
     Long: `Execute SQL scripts from a directory against a PostgreSQL database.
 
-Scripts are executed in order: Priority (ascending), then Sequence (ascending).
-Execution stops immediately on the first error.
+Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).
+By default, execution stops immediately on the first error. Use --ignore-errors to continue execution.
 
-The directory is scanned recursively for files matching the patterns:
+The directory is scanned recursively for all subdirectories and files matching the patterns:
   {priority}_{sequence}_{name}.sql or .pgsql (underscore format)
   {priority}-{sequence}-{name}.sql or .pgsql (hyphen format)
@@ -75,7 +76,7 @@ PostgreSQL Connection String Examples:
   postgresql://user:pass@host/dbname?sslmode=require
 
 Examples:
-  # Execute migration scripts
+  # Execute migration scripts from a directory (including subdirectories)
   relspec scripts execute --dir ./migrations \
     --conn "postgres://user:pass@localhost:5432/mydb"
@@ -86,7 +87,12 @@ Examples:
   # Execute with SSL disabled
   relspec scripts execute --dir ./sql \
-    --conn "postgres://user:pass@localhost/db?sslmode=disable"`,
+    --conn "postgres://user:pass@localhost/db?sslmode=disable"
+
+  # Continue executing even if errors occur
+  relspec scripts execute --dir ./migrations \
+    --conn "postgres://localhost/mydb" \
+    --ignore-errors`,
     RunE: runScriptsExecute,
 }
@@ -105,6 +111,7 @@ func init() {
     scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
     scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
     scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
+    scriptsExecuteCmd.Flags().BoolVar(&scriptsIgnoreErrors, "ignore-errors", false, "Continue executing scripts even if errors occur")
 
     err = scriptsExecuteCmd.MarkFlagRequired("dir")
     if err != nil {
@@ -149,7 +156,7 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
         return nil
     }
 
-    // Sort scripts by Priority then Sequence
+    // Sort scripts by Priority, Sequence, then Name
     sortedScripts := make([]*struct {
         name     string
         priority int
@@ -186,7 +193,10 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
         if sortedScripts[i].priority != sortedScripts[j].priority {
             return sortedScripts[i].priority < sortedScripts[j].priority
         }
-        return sortedScripts[i].sequence < sortedScripts[j].sequence
+        if sortedScripts[i].sequence != sortedScripts[j].sequence {
+            return sortedScripts[i].sequence < sortedScripts[j].sequence
+        }
+        return sortedScripts[i].name < sortedScripts[j].name
     })
 
     fmt.Fprintf(os.Stderr, "Found %d script(s) in execution order:\n\n", len(sortedScripts))
@@ -242,22 +252,44 @@ func runScriptsExecute(cmd *cobra.Command, args []string) error {
     fmt.Fprintf(os.Stderr, "  ✓ Found %d script(s)\n\n", len(schema.Scripts))
 
     // Step 2: Execute scripts
-    fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence)...\n\n")
+    fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence → Name)...\n\n")
 
     writer := sqlexec.NewWriter(&writers.WriterOptions{
         Metadata: map[string]any{
             "connection_string": scriptsConn,
+            "ignore_errors":     scriptsIgnoreErrors,
         },
     })
 
     if err := writer.WriteSchema(schema); err != nil {
         fmt.Fprintf(os.Stderr, "\n")
-        return fmt.Errorf("execution failed: %w", err)
+        return fmt.Errorf("script execution failed: %w", err)
     }
 
+    // Get execution results from writer metadata
+    totalCount := len(schema.Scripts)
+    successCount := totalCount
+    failedCount := 0
+    opts := writer.Options()
+    if total, exists := opts.Metadata["execution_total"].(int); exists {
+        totalCount = total
+    }
+    if success, exists := opts.Metadata["execution_success"].(int); exists {
+        successCount = success
+    }
+    if failed, exists := opts.Metadata["execution_failed"].(int); exists {
+        failedCount = failed
+    }
+
     fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
     fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
-    fmt.Fprintf(os.Stderr, "Successfully executed %d script(s)\n\n", len(schema.Scripts))
+    fmt.Fprintf(os.Stderr, "Total scripts: %d\n", totalCount)
+    fmt.Fprintf(os.Stderr, "Successful: %d\n", successCount)
+    if failedCount > 0 {
+        fmt.Fprintf(os.Stderr, "Failed: %d\n", failedCount)
+    }
+    fmt.Fprintf(os.Stderr, "\n")
     return nil
 }


@@ -183,7 +183,8 @@ func runSplit(cmd *cobra.Command, args []string) error {
         splitTargetType,
         splitTargetPath,
         splitPackageName,
-        "", // no schema filter for split
+        "",    // no schema filter for split
+        false, // no flatten-schema for split
     )
     if err != nil {
         return fmt.Errorf("failed to write output: %w", err)


@@ -0,0 +1,714 @@
package commontypes
import (
"testing"
)
func TestExtractBaseType(t *testing.T) {
tests := []struct {
name string
sqlType string
want string
}{
{"varchar with length", "varchar(100)", "varchar"},
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
{"numeric with precision", "numeric(10,2)", "numeric"},
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
{"decimal with precision", "decimal(15,3)", "decimal"},
{"char with length", "char(50)", "char"},
{"simple integer", "integer", "integer"},
{"simple text", "text", "text"},
{"bigint", "bigint", "bigint"},
{"With spaces", " varchar(100) ", "varchar"},
{"No parentheses", "boolean", "boolean"},
{"Empty string", "", ""},
{"Mixed case", "VarChar(100)", "varchar"},
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ExtractBaseType(tt.sqlType)
if got != tt.want {
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
}
})
}
}
func TestNormalizeType(t *testing.T) {
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
testCases := []string{
"varchar(100)",
"numeric(10,2)",
"integer",
"text",
" VARCHAR(255) ",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
extracted := ExtractBaseType(tc)
normalized := NormalizeType(tc)
if extracted != normalized {
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
tc, extracted, tc, normalized)
}
})
}
}
func TestSQLToGo(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "int32"},
{"bigint nullable", "bigint", true, "int64"},
{"smallint nullable", "smallint", true, "int16"},
{"serial nullable", "serial", true, "int32"},
// Integer types (not nullable)
{"integer not nullable", "integer", false, "*int32"},
{"bigint not nullable", "bigint", false, "*int64"},
{"smallint not nullable", "smallint", false, "*int16"},
// String types (nullable)
{"text nullable", "text", true, "string"},
{"varchar nullable", "varchar", true, "string"},
{"varchar with length nullable", "varchar(100)", true, "string"},
// String types (not nullable)
{"text not nullable", "text", false, "*string"},
{"varchar not nullable", "varchar", false, "*string"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "*bool"},
// Float types
{"real nullable", "real", true, "float32"},
{"double precision nullable", "double precision", true, "float64"},
{"real not nullable", "real", false, "*float32"},
{"double precision not nullable", "double precision", false, "*float64"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "time.Time"},
{"date nullable", "date", true, "time.Time"},
{"timestamp not nullable", "timestamp", false, "*time.Time"},
// Binary
{"bytea nullable", "bytea", true, "[]byte"},
{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer
// UUID
{"uuid nullable", "uuid", true, "string"},
{"uuid not nullable", "uuid", false, "*string"},
// JSON
{"json nullable", "json", true, "string"},
{"jsonb nullable", "jsonb", true, "string"},
// Array
{"array nullable", "array", true, "[]string"},
{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer
// Unknown types
{"unknown type nullable", "unknowntype", true, "interface{}"},
{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToGo(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToTypeScript(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types
{"integer nullable", "integer", true, "number"},
{"integer not nullable", "integer", false, "number | null"},
{"bigint nullable", "bigint", true, "number"},
{"bigint not nullable", "bigint", false, "number | null"},
// String types
{"text nullable", "text", true, "string"},
{"text not nullable", "text", false, "string | null"},
{"varchar nullable", "varchar", true, "string"},
{"varchar(100) nullable", "varchar(100)", true, "string"},
// Boolean
{"boolean nullable", "boolean", true, "boolean"},
{"boolean not nullable", "boolean", false, "boolean | null"},
// Float types
{"real nullable", "real", true, "number"},
{"double precision nullable", "double precision", true, "number"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "Date"},
{"date nullable", "date", true, "Date"},
{"timestamp not nullable", "timestamp", false, "Date | null"},
// Binary
{"bytea nullable", "bytea", true, "Buffer"},
{"bytea not nullable", "bytea", false, "Buffer | null"},
// JSON
{"json nullable", "json", true, "any"},
{"jsonb nullable", "jsonb", true, "any"},
// UUID
{"uuid nullable", "uuid", true, "string"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "any"},
{"unknown type not nullable", "unknowntype", false, "any | null"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToTypeScript(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToPython(t *testing.T) {
tests := []struct {
name string
sqlType string
want string
}{
// Integer types
{"integer", "integer", "int"},
{"bigint", "bigint", "int"},
{"smallint", "smallint", "int"},
// String types
{"text", "text", "str"},
{"varchar", "varchar", "str"},
{"varchar(100)", "varchar(100)", "str"},
// Boolean
{"boolean", "boolean", "bool"},
// Float types
{"real", "real", "float"},
{"double precision", "double precision", "float"},
{"numeric", "numeric", "Decimal"},
{"decimal", "decimal", "Decimal"},
// Date/Time types
{"timestamp", "timestamp", "datetime"},
{"date", "date", "date"},
{"time", "time", "time"},
// Binary
{"bytea", "bytea", "bytes"},
// JSON
{"json", "json", "dict"},
{"jsonb", "jsonb", "dict"},
// UUID
{"uuid", "uuid", "UUID"},
// Array
{"array", "array", "list"},
// Unknown types
{"unknown type", "unknowntype", "Any"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToPython(tt.sqlType)
if got != tt.want {
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
}
})
}
}
func TestSQLToCSharp(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "int"},
{"bigint nullable", "bigint", true, "long"},
{"smallint nullable", "smallint", true, "short"},
// Integer types (not nullable - value types get ?)
{"integer not nullable", "integer", false, "int?"},
{"bigint not nullable", "bigint", false, "long?"},
{"smallint not nullable", "smallint", false, "short?"},
// String types (reference types, no ? needed)
{"text nullable", "text", true, "string"},
{"text not nullable", "text", false, "string"},
{"varchar nullable", "varchar", true, "string"},
{"varchar(100) nullable", "varchar(100)", true, "string"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "bool?"},
// Float types
{"real nullable", "real", true, "float"},
{"double precision nullable", "double precision", true, "double"},
{"decimal nullable", "decimal", true, "decimal"},
{"real not nullable", "real", false, "float?"},
{"double precision not nullable", "double precision", false, "double?"},
{"decimal not nullable", "decimal", false, "decimal?"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "DateTime"},
{"date nullable", "date", true, "DateTime"},
{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
{"timestamp not nullable", "timestamp", false, "DateTime?"},
{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},
// Binary (array type, no ?)
{"bytea nullable", "bytea", true, "byte[]"},
{"bytea not nullable", "bytea", false, "byte[]"},
// UUID
{"uuid nullable", "uuid", true, "Guid"},
{"uuid not nullable", "uuid", false, "Guid?"},
// JSON
{"json nullable", "json", true, "string"},
// Unknown types (object is reference type)
{"unknown type nullable", "unknowntype", true, "object"},
{"unknown type not nullable", "unknowntype", false, "object"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToCSharp(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestNeedsTimeImport(t *testing.T) {
tests := []struct {
name string
goType string
want bool
}{
{"time.Time type", "time.Time", true},
{"pointer to time.Time", "*time.Time", true},
{"int32 type", "int32", false},
{"string type", "string", false},
{"bool type", "bool", false},
{"[]byte type", "[]byte", false},
{"interface{}", "interface{}", false},
{"empty string", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NeedsTimeImport(tt.goType)
if got != tt.want {
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
}
})
}
}
func TestGoTypeMap(t *testing.T) {
// Test that the map contains expected entries
expectedMappings := map[string]string{
"integer": "int32",
"bigint": "int64",
"text": "string",
"boolean": "bool",
"double precision": "float64",
"bytea": "[]byte",
"timestamp": "time.Time",
"uuid": "string",
"json": "string",
}
for sqlType, expectedGoType := range expectedMappings {
if goType, ok := GoTypeMap[sqlType]; !ok {
t.Errorf("GoTypeMap missing entry for %q", sqlType)
} else if goType != expectedGoType {
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
}
}
if len(GoTypeMap) == 0 {
t.Error("GoTypeMap is empty")
}
}
func TestTypeScriptTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "number",
"bigint": "number",
"text": "string",
"boolean": "boolean",
"double precision": "number",
"bytea": "Buffer",
"timestamp": "Date",
"uuid": "string",
"json": "any",
}
for sqlType, expectedTSType := range expectedMappings {
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
} else if tsType != expectedTSType {
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
}
}
if len(TypeScriptTypeMap) == 0 {
t.Error("TypeScriptTypeMap is empty")
}
}
func TestPythonTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "int",
"bigint": "int",
"text": "str",
"boolean": "bool",
"real": "float",
"numeric": "Decimal",
"bytea": "bytes",
"date": "date",
"uuid": "UUID",
"json": "dict",
}
for sqlType, expectedPyType := range expectedMappings {
if pyType, ok := PythonTypeMap[sqlType]; !ok {
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
} else if pyType != expectedPyType {
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
}
}
if len(PythonTypeMap) == 0 {
t.Error("PythonTypeMap is empty")
}
}
func TestCSharpTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "int",
"bigint": "long",
"smallint": "short",
"text": "string",
"boolean": "bool",
"double precision": "double",
"decimal": "decimal",
"bytea": "byte[]",
"timestamp": "DateTime",
"uuid": "Guid",
"json": "string",
}
for sqlType, expectedCSType := range expectedMappings {
if csType, ok := CSharpTypeMap[sqlType]; !ok {
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
} else if csType != expectedCSType {
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
}
}
if len(CSharpTypeMap) == 0 {
t.Error("CSharpTypeMap is empty")
}
}
func TestSQLToJava(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types
{"integer nullable", "integer", true, "Integer"},
{"integer not nullable", "integer", false, "Integer"},
{"bigint nullable", "bigint", true, "Long"},
{"smallint nullable", "smallint", true, "Short"},
// String types
{"text nullable", "text", true, "String"},
{"varchar nullable", "varchar", true, "String"},
{"varchar(100) nullable", "varchar(100)", true, "String"},
// Boolean
{"boolean nullable", "boolean", true, "Boolean"},
// Float types
{"real nullable", "real", true, "Float"},
{"double precision nullable", "double precision", true, "Double"},
{"numeric nullable", "numeric", true, "BigDecimal"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "Timestamp"},
{"date nullable", "date", true, "Date"},
{"time nullable", "time", true, "Time"},
// Binary
{"bytea nullable", "bytea", true, "byte[]"},
// UUID
{"uuid nullable", "uuid", true, "UUID"},
// JSON
{"json nullable", "json", true, "String"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "Object"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToJava(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToPhp(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "int"},
{"bigint nullable", "bigint", true, "int"},
{"smallint nullable", "smallint", true, "int"},
// Integer types (not nullable)
{"integer not nullable", "integer", false, "?int"},
{"bigint not nullable", "bigint", false, "?int"},
// String types
{"text nullable", "text", true, "string"},
{"text not nullable", "text", false, "?string"},
{"varchar nullable", "varchar", true, "string"},
{"varchar(100) nullable", "varchar(100)", true, "string"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "?bool"},
// Float types
{"real nullable", "real", true, "float"},
{"double precision nullable", "double precision", true, "float"},
{"real not nullable", "real", false, "?float"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "\\DateTime"},
{"date nullable", "date", true, "\\DateTime"},
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
// Binary
{"bytea nullable", "bytea", true, "string"},
{"bytea not nullable", "bytea", false, "?string"},
// JSON
{"json nullable", "json", true, "array"},
{"json not nullable", "json", false, "?array"},
// UUID
{"uuid nullable", "uuid", true, "string"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "mixed"},
{"unknown type not nullable", "unknowntype", false, "mixed"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToPhp(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToRust(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "i32"},
{"bigint nullable", "bigint", true, "i64"},
{"smallint nullable", "smallint", true, "i16"},
// Integer types (not nullable)
{"integer not nullable", "integer", false, "Option<i32>"},
{"bigint not nullable", "bigint", false, "Option<i64>"},
{"smallint not nullable", "smallint", false, "Option<i16>"},
// String types
{"text nullable", "text", true, "String"},
{"text not nullable", "text", false, "Option<String>"},
{"varchar nullable", "varchar", true, "String"},
{"varchar(100) nullable", "varchar(100)", true, "String"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "Option<bool>"},
// Float types
{"real nullable", "real", true, "f32"},
{"double precision nullable", "double precision", true, "f64"},
{"real not nullable", "real", false, "Option<f32>"},
{"double precision not nullable", "double precision", false, "Option<f64>"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
{"date nullable", "date", true, "NaiveDate"},
{"time nullable", "time", true, "NaiveTime"},
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
// Binary
{"bytea nullable", "bytea", true, "Vec<u8>"},
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
// JSON
{"json nullable", "json", true, "serde_json::Value"},
{"json not nullable", "json", false, "Option<serde_json::Value>"},
// UUID
{"uuid nullable", "uuid", true, "String"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "String"},
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToRust(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestJavaTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "Integer",
"bigint": "Long",
"smallint": "Short",
"text": "String",
"boolean": "Boolean",
"double precision": "Double",
"numeric": "BigDecimal",
"bytea": "byte[]",
"timestamp": "Timestamp",
"uuid": "UUID",
"date": "Date",
}
for sqlType, expectedJavaType := range expectedMappings {
if javaType, ok := JavaTypeMap[sqlType]; !ok {
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
} else if javaType != expectedJavaType {
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
}
}
if len(JavaTypeMap) == 0 {
t.Error("JavaTypeMap is empty")
}
}
func TestPHPTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "int",
"bigint": "int",
"text": "string",
"boolean": "bool",
"double precision": "float",
"bytea": "string",
"timestamp": "\\DateTime",
"uuid": "string",
"json": "array",
}
for sqlType, expectedPHPType := range expectedMappings {
if phpType, ok := PHPTypeMap[sqlType]; !ok {
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
} else if phpType != expectedPHPType {
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
}
}
if len(PHPTypeMap) == 0 {
t.Error("PHPTypeMap is empty")
}
}
func TestRustTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "i32",
"bigint": "i64",
"smallint": "i16",
"text": "String",
"boolean": "bool",
"double precision": "f64",
"real": "f32",
"bytea": "Vec<u8>",
"timestamp": "NaiveDateTime",
"timestamptz": "DateTime<Utc>",
"date": "NaiveDate",
"json": "serde_json::Value",
}
for sqlType, expectedRustType := range expectedMappings {
if rustType, ok := RustTypeMap[sqlType]; !ok {
t.Errorf("RustTypeMap missing entry for %q", sqlType)
} else if rustType != expectedRustType {
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
}
}
if len(RustTypeMap) == 0 {
t.Error("RustTypeMap is empty")
}
}

pkg/diff/diff_test.go

@@ -0,0 +1,558 @@
package diff
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestCompareDatabases(t *testing.T) {
tests := []struct {
name string
source *models.Database
target *models.Database
want func(*DiffResult) bool
}{
{
name: "identical databases",
source: &models.Database{
Name: "source",
Schemas: []*models.Schema{},
},
target: &models.Database{
Name: "target",
Schemas: []*models.Schema{},
},
want: func(r *DiffResult) bool {
return r.Source == "source" && r.Target == "target" &&
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
},
},
{
name: "different schemas",
source: &models.Database{
Name: "source",
Schemas: []*models.Schema{
{Name: "public"},
},
},
target: &models.Database{
Name: "target",
Schemas: []*models.Schema{},
},
want: func(r *DiffResult) bool {
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := CompareDatabases(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("CompareDatabases() result doesn't match expectations")
}
})
}
}
func TestCompareColumns(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Column
target map[string]*models.Column
want func(*ColumnDiff) bool
}{
{
name: "identical columns",
source: map[string]*models.Column{},
target: map[string]*models.Column{},
want: func(d *ColumnDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing column",
source: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
target: map[string]*models.Column{},
want: func(d *ColumnDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
},
},
{
name: "extra column",
source: map[string]*models.Column{},
target: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
want: func(d *ColumnDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
},
},
{
name: "modified column type",
source: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
target: map[string]*models.Column{
"id": {Name: "id", Type: "bigint"},
},
want: func(d *ColumnDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
d.Modified[0].Changes["type"] != nil
},
},
{
name: "modified column nullable",
source: map[string]*models.Column{
"name": {Name: "name", Type: "text", NotNull: true},
},
target: map[string]*models.Column{
"name": {Name: "name", Type: "text", NotNull: false},
},
want: func(d *ColumnDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
},
},
{
name: "modified column length",
source: map[string]*models.Column{
"name": {Name: "name", Type: "varchar", Length: 100},
},
target: map[string]*models.Column{
"name": {Name: "name", Type: "varchar", Length: 255},
},
want: func(d *ColumnDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareColumns(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareColumns() result doesn't match expectations")
}
})
}
}
func TestCompareColumnDetails(t *testing.T) {
tests := []struct {
name string
source *models.Column
target *models.Column
want int // number of changes
}{
{
name: "identical columns",
source: &models.Column{Name: "id", Type: "integer"},
target: &models.Column{Name: "id", Type: "integer"},
want: 0,
},
{
name: "type change",
source: &models.Column{Name: "id", Type: "integer"},
target: &models.Column{Name: "id", Type: "bigint"},
want: 1,
},
{
name: "length change",
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
want: 1,
},
{
name: "precision change",
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
want: 1,
},
{
name: "scale change",
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
want: 1,
},
{
name: "not null change",
source: &models.Column{Name: "name", Type: "text", NotNull: true},
target: &models.Column{Name: "name", Type: "text", NotNull: false},
want: 1,
},
{
name: "auto increment change",
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
want: 1,
},
{
name: "primary key change",
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
want: 1,
},
{
name: "multiple changes",
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
want: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareColumnDetails(tt.source, tt.target)
if len(got) != tt.want {
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
}
})
}
}
func TestCompareIndexes(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Index
target map[string]*models.Index
want func(*IndexDiff) bool
}{
{
name: "identical indexes",
source: map[string]*models.Index{},
target: map[string]*models.Index{},
want: func(d *IndexDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing index",
source: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
},
target: map[string]*models.Index{},
want: func(d *IndexDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
},
},
{
name: "extra index",
source: map[string]*models.Index{},
target: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
},
want: func(d *IndexDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
},
},
{
name: "modified index uniqueness",
source: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
},
target: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
},
want: func(d *IndexDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareIndexes(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareIndexes() result doesn't match expectations")
}
})
}
}
func TestCompareConstraints(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Constraint
target map[string]*models.Constraint
want func(*ConstraintDiff) bool
}{
{
name: "identical constraints",
source: map[string]*models.Constraint{},
target: map[string]*models.Constraint{},
want: func(d *ConstraintDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing constraint",
source: map[string]*models.Constraint{
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
},
target: map[string]*models.Constraint{},
want: func(d *ConstraintDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
},
},
{
name: "extra constraint",
source: map[string]*models.Constraint{},
target: map[string]*models.Constraint{
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
},
want: func(d *ConstraintDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareConstraints(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareConstraints() result doesn't match expectations")
}
})
}
}
func TestCompareRelationships(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Relationship
target map[string]*models.Relationship
want func(*RelationshipDiff) bool
}{
{
name: "identical relationships",
source: map[string]*models.Relationship{},
target: map[string]*models.Relationship{},
want: func(d *RelationshipDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing relationship",
source: map[string]*models.Relationship{
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
},
target: map[string]*models.Relationship{},
want: func(d *RelationshipDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
},
},
{
name: "extra relationship",
source: map[string]*models.Relationship{},
target: map[string]*models.Relationship{
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
},
want: func(d *RelationshipDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareRelationships(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareRelationships() result doesn't match expectations")
}
})
}
}
func TestCompareTables(t *testing.T) {
tests := []struct {
name string
source []*models.Table
target []*models.Table
want func(*TableDiff) bool
}{
{
name: "identical tables",
source: []*models.Table{},
target: []*models.Table{},
want: func(d *TableDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing table",
source: []*models.Table{
{Name: "users", Schema: "public"},
},
target: []*models.Table{},
want: func(d *TableDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
},
},
{
name: "extra table",
source: []*models.Table{},
target: []*models.Table{
{Name: "users", Schema: "public"},
},
want: func(d *TableDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
},
},
{
name: "modified table",
source: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
},
},
target: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint"},
},
},
},
want: func(d *TableDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareTables(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareTables() result doesn't match expectations")
}
})
}
}
func TestCompareSchemas(t *testing.T) {
tests := []struct {
name string
source []*models.Schema
target []*models.Schema
want func(*SchemaDiff) bool
}{
{
name: "identical schemas",
source: []*models.Schema{},
target: []*models.Schema{},
want: func(d *SchemaDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing schema",
source: []*models.Schema{
{Name: "public"},
},
target: []*models.Schema{},
want: func(d *SchemaDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
},
},
{
name: "extra schema",
source: []*models.Schema{},
target: []*models.Schema{
{Name: "public"},
},
want: func(d *SchemaDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareSchemas(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareSchemas() result doesn't match expectations")
}
})
}
}
func TestIsEmpty(t *testing.T) {
tests := []struct {
name string
v interface{}
want bool
}{
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isEmpty(tt.v)
if got != tt.want {
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
}
})
}
}
func TestComputeSummary(t *testing.T) {
tests := []struct {
name string
result *DiffResult
want func(*Summary) bool
}{
{
name: "empty diff",
result: &DiffResult{
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
},
want: func(s *Summary) bool {
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
},
},
{
name: "schemas with differences",
result: &DiffResult{
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "schema1"}},
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
Modified: []*SchemaChange{
{Name: "public"},
},
},
},
want: func(s *Summary) bool {
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ComputeSummary(tt.result)
if !tt.want(got) {
t.Errorf("ComputeSummary() result doesn't match expectations")
}
})
}
}
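Taken together, the compare* tests above pin down how the pieces compose: the per-level compare functions feed a DiffResult, ComputeSummary reduces it to counts, and isEmpty gates reporting. A minimal in-package sketch of that flow, using only identifiers these tests exercise (the wrapping function itself is an assumption, not part of the package):

// compareAll is a hypothetical wrapper illustrating the composition;
// compareSchemas, ComputeSummary, and isEmpty are the tested identifiers.
func compareAll(sourceName, targetName string, src, tgt []*models.Schema) (*DiffResult, *Summary) {
	result := &DiffResult{
		Source:  sourceName,
		Target:  targetName,
		Schemas: compareSchemas(src, tgt),
	}
	// isEmpty(result.Schemas) can short-circuit reporting when nothing differs.
	return result, ComputeSummary(result)
}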

pkg/diff/formatters_test.go Normal file

@@ -0,0 +1,440 @@
package diff
import (
"bytes"
"encoding/json"
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestFormatDiff(t *testing.T) {
result := &DiffResult{
Source: "source_db",
Target: "target_db",
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
}
tests := []struct {
name string
format OutputFormat
wantErr bool
}{
{"summary format", FormatSummary, false},
{"json format", FormatJSON, false},
{"html format", FormatHTML, false},
{"invalid format", OutputFormat("invalid"), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := FormatDiff(result, tt.format, &buf)
if (err != nil) != tt.wantErr {
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && buf.Len() == 0 {
t.Error("FormatDiff() produced empty output")
}
})
}
}
func TestFormatSummary(t *testing.T) {
tests := []struct {
name string
result *DiffResult
wantStr []string // strings that should appear in output
}{
{
name: "no differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
},
wantStr: []string{"source", "target", "No differences found"},
},
{
name: "with schema differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "schema1"}},
Extra: []*models.Schema{{Name: "schema2"}},
Modified: []*SchemaChange{
{Name: "public"},
},
},
},
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
},
{
name: "with table differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Missing: []*models.Table{{Name: "users"}},
Extra: []*models.Table{{Name: "posts"}},
Modified: []*TableChange{
{Name: "comments", Schema: "public"},
},
},
},
},
},
},
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := formatSummary(tt.result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
for _, want := range tt.wantStr {
if !strings.Contains(output, want) {
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
}
}
})
}
}
func TestFormatJSON(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "schema1"}},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
}
var buf bytes.Buffer
err := formatJSON(result, &buf)
if err != nil {
t.Errorf("formatJSON() error = %v", err)
return
}
// Check if output is valid JSON
var decoded DiffResult
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
t.Errorf("formatJSON() produced invalid JSON: %v", err)
}
// Check basic structure
if decoded.Source != "source" {
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
}
if decoded.Target != "target" {
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
}
if len(decoded.Schemas.Missing) != 1 {
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
}
}
func TestFormatHTML(t *testing.T) {
tests := []struct {
name string
result *DiffResult
wantStr []string // HTML elements/content that should appear
}{
{
name: "basic HTML structure",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
},
wantStr: []string{
"<!DOCTYPE html>",
"<title>Database Diff Report</title>",
"source",
"target",
},
},
{
name: "with schema differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "missing_schema"}},
Extra: []*models.Schema{{Name: "extra_schema"}},
Modified: []*SchemaChange{},
},
},
wantStr: []string{
"<!DOCTYPE html>",
"missing_schema",
"extra_schema",
"MISSING",
"EXTRA",
},
},
{
name: "with table modifications",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Columns: &ColumnDiff{
Missing: []*models.Column{{Name: "email", Type: "text"}},
},
},
},
},
},
},
},
},
wantStr: []string{
"public",
"users",
"email",
"text",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := formatHTML(tt.result, &buf)
if err != nil {
t.Errorf("formatHTML() error = %v", err)
return
}
output := buf.String()
for _, want := range tt.wantStr {
if !strings.Contains(output, want) {
t.Errorf("formatHTML() output doesn't contain %q", want)
}
}
})
}
}
func TestFormatSummaryWithColumns(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Columns: &ColumnDiff{
Missing: []*models.Column{{Name: "email"}},
Extra: []*models.Column{{Name: "phone"}, {Name: "address"}},
Modified: []*ColumnChange{
{Name: "name"},
},
},
},
},
},
},
},
},
}
var buf bytes.Buffer
err := formatSummary(result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
wantStrings := []string{
"Columns:",
"Missing: 1",
"Extra: 2",
"Modified: 1",
}
for _, want := range wantStrings {
if !strings.Contains(output, want) {
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
}
}
}
func TestFormatSummaryWithIndexes(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Indexes: &IndexDiff{
Missing: []*models.Index{{Name: "idx_email"}},
Extra: []*models.Index{{Name: "idx_phone"}},
Modified: []*IndexChange{{Name: "idx_name"}},
},
},
},
},
},
},
},
}
var buf bytes.Buffer
err := formatSummary(result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
if !strings.Contains(output, "Indexes:") {
t.Error("formatSummary() output doesn't contain Indexes section")
}
if !strings.Contains(output, "Missing: 1") {
t.Error("formatSummary() output doesn't contain correct missing count")
}
}
func TestFormatSummaryWithConstraints(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Constraints: &ConstraintDiff{
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
},
},
},
},
},
},
},
}
var buf bytes.Buffer
err := formatSummary(result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
if !strings.Contains(output, "Constraints:") {
t.Error("formatSummary() output doesn't contain Constraints section")
}
}
func TestFormatJSONIndentation(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "test"}},
},
}
var buf bytes.Buffer
err := formatJSON(result, &buf)
if err != nil {
t.Errorf("formatJSON() error = %v", err)
return
}
// Check that JSON is indented (has newlines and spaces)
output := buf.String()
if !strings.Contains(output, "\n") {
t.Error("formatJSON() should produce indented JSON with newlines")
}
if !strings.Contains(output, " ") {
t.Error("formatJSON() should produce indented JSON with spaces")
}
}
func TestOutputFormatConstants(t *testing.T) {
tests := []struct {
name string
format OutputFormat
want string
}{
{"summary constant", FormatSummary, "summary"},
{"json constant", FormatJSON, "json"},
{"html constant", FormatHTML, "html"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.format) != tt.want {
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
}
})
}
}
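For reference, a minimal sketch of the declarations the formatter tests assume; the string values are fixed by TestOutputFormatConstants above, while the exact placement in the package is assumed:

type OutputFormat string

const (
	FormatSummary OutputFormat = "summary"
	FormatJSON    OutputFormat = "json"
	FormatHTML    OutputFormat = "html"
)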


@@ -0,0 +1,238 @@
package inspector
import (
"testing"
)
func TestNewInspector(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
if inspector == nil {
t.Fatal("NewInspector() returned nil")
}
if inspector.db != db {
t.Error("NewInspector() database not set correctly")
}
if inspector.config != config {
t.Error("NewInspector() config not set correctly")
}
}
func TestInspect(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() returned error: %v", err)
}
if report == nil {
t.Fatal("Inspect() returned nil report")
}
if report.Database != db.Name {
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
}
if report.Summary.TotalRules != len(config.Rules) {
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
}
if len(report.Violations) == 0 {
t.Error("Inspect() returned no violations, expected some results")
}
}
func TestInspectWithDisabledRules(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
// Disable all rules
for name := range config.Rules {
rule := config.Rules[name]
rule.Enabled = "off"
config.Rules[name] = rule
}
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
}
if report.Summary.RulesChecked != 0 {
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
}
if report.Summary.RulesSkipped != len(config.Rules) {
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
}
}
func TestInspectWithEnforcedRules(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
// Enable only one rule and enforce it
for name := range config.Rules {
rule := config.Rules[name]
rule.Enabled = "off"
config.Rules[name] = rule
}
primaryKeyRule := config.Rules["primary_key_naming"]
primaryKeyRule.Enabled = "enforce"
primaryKeyRule.Pattern = "^id$"
config.Rules["primary_key_naming"] = primaryKeyRule
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() returned error: %v", err)
}
if report.Summary.RulesChecked != 1 {
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
}
// All results should be at error level for enforced rules
for _, violation := range report.Violations {
if violation.Level != "error" {
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
}
}
}
func TestGenerateSummary(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
results := []ValidationResult{
{RuleName: "rule1", Passed: true, Level: "error"},
{RuleName: "rule2", Passed: false, Level: "error"},
{RuleName: "rule3", Passed: false, Level: "warning"},
{RuleName: "rule4", Passed: true, Level: "warning"},
}
summary := inspector.generateSummary(results)
if summary.PassedCount != 2 {
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
}
if summary.ErrorCount != 1 {
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
}
if summary.WarningCount != 1 {
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
}
}
func TestHasErrors(t *testing.T) {
tests := []struct {
name string
report *InspectorReport
want bool
}{
{
name: "with errors",
report: &InspectorReport{
Summary: ReportSummary{
ErrorCount: 5,
},
},
want: true,
},
{
name: "without errors",
report: &InspectorReport{
Summary: ReportSummary{
ErrorCount: 0,
WarningCount: 3,
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.report.HasErrors(); got != tt.want {
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetValidator(t *testing.T) {
tests := []struct {
name string
functionName string
wantExists bool
}{
{"primary_key_naming", "primary_key_naming", true},
{"primary_key_datatype", "primary_key_datatype", true},
{"foreign_key_column_naming", "foreign_key_column_naming", true},
{"table_regexpr", "table_regexpr", true},
{"column_regexpr", "column_regexpr", true},
{"reserved_words", "reserved_words", true},
{"have_primary_key", "have_primary_key", true},
{"orphaned_foreign_key", "orphaned_foreign_key", true},
{"circular_dependency", "circular_dependency", true},
{"unknown_function", "unknown_function", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, exists := getValidator(tt.functionName)
if exists != tt.wantExists {
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
}
})
}
}
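// A plausible sketch (assumed, not taken from the source) of the lookup that
// TestGetValidator exercises: a package-level registry keyed by rule function
// name. The validator signature matches the validate* functions tested in this
// package; the models import is assumed.
var validators = map[string]func(*models.Database, Rule, string) []ValidationResult{
	"primary_key_naming": validatePrimaryKeyNaming,
	// ... remaining validators registered the same way
}

func getValidator(name string) (func(*models.Database, Rule, string) []ValidationResult, bool) {
	v, ok := validators[name]
	return v, ok
}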
func TestCreateResult(t *testing.T) {
result := createResult(
"test_rule",
true,
"Test message",
"schema.table.column",
map[string]interface{}{
"key1": "value1",
"key2": 42,
},
)
if result.RuleName != "test_rule" {
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
}
if !result.Passed {
t.Error("createResult() Passed = false, want true")
}
if result.Message != "Test message" {
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
}
if result.Location != "schema.table.column" {
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
}
if len(result.Context) != 2 {
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
}
}
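Between them, these tests imply the typical wiring for callers. A short usage sketch; the runInspection wrapper, its configPath argument, and the fmt and models imports are assumptions, while LoadConfig's fallback to defaults for a missing file is verified in rules_test.go:

func runInspection(db *models.Database, configPath string) error {
	// LoadConfig falls back to the default rule set when the file is absent.
	config, err := LoadConfig(configPath)
	if err != nil {
		return err
	}
	report, err := NewInspector(db, config).Inspect()
	if err != nil {
		return err
	}
	if report.HasErrors() {
		return fmt.Errorf("inspection failed with %d errors", report.Summary.ErrorCount)
	}
	return nil
}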


@@ -0,0 +1,366 @@
package inspector
import (
"bytes"
"encoding/json"
"strings"
"testing"
"time"
)
func createTestReport() *InspectorReport {
return &InspectorReport{
Summary: ReportSummary{
TotalRules: 10,
RulesChecked: 8,
RulesSkipped: 2,
ErrorCount: 3,
WarningCount: 5,
PassedCount: 12,
},
Violations: []ValidationResult{
{
RuleName: "primary_key_naming",
Level: "error",
Message: "Primary key should start with 'id_'",
Location: "public.users.user_id",
Passed: false,
Context: map[string]interface{}{
"schema": "public",
"table": "users",
"column": "user_id",
"pattern": "^id_",
},
},
{
RuleName: "table_name_length",
Level: "warning",
Message: "Table name too long",
Location: "public.very_long_table_name_that_exceeds_limits",
Passed: false,
Context: map[string]interface{}{
"schema": "public",
"table": "very_long_table_name_that_exceeds_limits",
"length": 44,
"max_length": 32,
},
},
},
GeneratedAt: time.Now(),
Database: "testdb",
SourceFormat: "postgresql",
}
}
func TestNewMarkdownFormatter(t *testing.T) {
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
if formatter == nil {
t.Fatal("NewMarkdownFormatter() returned nil")
}
// Buffer is not a terminal, so colors should be disabled
if formatter.UseColors {
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
}
}
func TestNewJSONFormatter(t *testing.T) {
formatter := NewJSONFormatter()
if formatter == nil {
t.Fatal("NewJSONFormatter() returned nil")
}
}
func TestMarkdownFormatter_Format(t *testing.T) {
report := createTestReport()
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
}
// Check that output contains expected sections
if !strings.Contains(output, "# RelSpec Inspector Report") {
t.Error("Markdown output missing header")
}
if !strings.Contains(output, "Database:") {
t.Error("Markdown output missing database field")
}
if !strings.Contains(output, "testdb") {
t.Error("Markdown output missing database name")
}
if !strings.Contains(output, "Summary") {
t.Error("Markdown output missing summary section")
}
if !strings.Contains(output, "Rules Checked: 8") {
t.Error("Markdown output missing rules checked count")
}
if !strings.Contains(output, "Errors: 3") {
t.Error("Markdown output missing error count")
}
if !strings.Contains(output, "Warnings: 5") {
t.Error("Markdown output missing warning count")
}
if !strings.Contains(output, "Violations") {
t.Error("Markdown output missing violations section")
}
if !strings.Contains(output, "primary_key_naming") {
t.Error("Markdown output missing rule name")
}
if !strings.Contains(output, "public.users.user_id") {
t.Error("Markdown output missing location")
}
}
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
report := &InspectorReport{
Summary: ReportSummary{
TotalRules: 10,
RulesChecked: 10,
RulesSkipped: 0,
ErrorCount: 0,
WarningCount: 0,
PassedCount: 50,
},
Violations: []ValidationResult{},
GeneratedAt: time.Now(),
Database: "testdb",
SourceFormat: "postgresql",
}
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
}
if !strings.Contains(output, "No violations found") {
t.Error("Markdown output should indicate no violations")
}
}
func TestJSONFormatter_Format(t *testing.T) {
report := createTestReport()
formatter := NewJSONFormatter()
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
}
// Verify it's valid JSON
var decoded InspectorReport
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
}
// Check key fields
if decoded.Database != "testdb" {
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
}
if decoded.Summary.ErrorCount != 3 {
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
}
if len(decoded.Violations) != 2 {
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
}
}
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
header := formatter.formatHeader("Test Header")
if !strings.Contains(header, "# Test Header") {
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
}
}
func TestMarkdownFormatter_FormatBold(t *testing.T) {
tests := []struct {
name string
useColors bool
text string
wantContains string
}{
{
name: "without colors",
useColors: false,
text: "Bold Text",
wantContains: "**Bold Text**",
},
{
name: "with colors",
useColors: true,
text: "Bold Text",
wantContains: "Bold Text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: tt.useColors}
result := formatter.formatBold(tt.text)
if !strings.Contains(result, tt.wantContains) {
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
}
})
}
}
func TestMarkdownFormatter_Colorize(t *testing.T) {
tests := []struct {
name string
useColors bool
text string
color string
wantColor bool
}{
{
name: "without colors",
useColors: false,
text: "Test",
color: colorRed,
wantColor: false,
},
{
name: "with colors",
useColors: true,
text: "Test",
color: colorRed,
wantColor: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: tt.useColors}
result := formatter.colorize(tt.text, tt.color)
hasColor := strings.Contains(result, tt.color)
if hasColor != tt.wantColor {
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
}
if !strings.Contains(result, tt.text) {
t.Errorf("colorize() doesn't contain original text %q", tt.text)
}
})
}
}
func TestMarkdownFormatter_FormatContext(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: false}
context := map[string]interface{}{
"schema": "public",
"table": "users",
"column": "id",
"pattern": "^id_",
"max_length": 64,
}
result := formatter.formatContext(context)
// Should not include schema, table, column (they're in location)
if strings.Contains(result, "schema") {
t.Error("formatContext() should skip schema field")
}
if strings.Contains(result, "table=") {
t.Error("formatContext() should skip table field")
}
if strings.Contains(result, "column=") {
t.Error("formatContext() should skip column field")
}
// Should include other fields
if !strings.Contains(result, "pattern") {
t.Error("formatContext() should include pattern field")
}
if !strings.Contains(result, "max_length") {
t.Error("formatContext() should include max_length field")
}
}
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: false}
violation := ValidationResult{
RuleName: "test_rule",
Level: "error",
Message: "Test violation message",
Location: "public.users.id",
Passed: false,
Context: map[string]interface{}{
"pattern": "^id_",
},
}
result := formatter.formatViolation(violation, colorRed)
if !strings.Contains(result, "test_rule") {
t.Error("formatViolation() should include rule name")
}
if !strings.Contains(result, "Test violation message") {
t.Error("formatViolation() should include message")
}
if !strings.Contains(result, "public.users.id") {
t.Error("formatViolation() should include location")
}
if !strings.Contains(result, "Location:") {
t.Error("formatViolation() should include Location label")
}
if !strings.Contains(result, "Message:") {
t.Error("formatViolation() should include Message label")
}
}
func TestReportFormatConstants(t *testing.T) {
// Test that color constants are defined
if colorReset == "" {
t.Error("colorReset is not defined")
}
if colorRed == "" {
t.Error("colorRed is not defined")
}
if colorYellow == "" {
t.Error("colorYellow is not defined")
}
if colorGreen == "" {
t.Error("colorGreen is not defined")
}
if colorBold == "" {
t.Error("colorBold is not defined")
}
}
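TestReportFormatConstants only asserts that the color constants are non-empty. A sketch of plausible definitions, assuming standard ANSI escape sequences (the exact codes are not pinned down by the tests):

const (
	colorReset  = "\033[0m"
	colorRed    = "\033[31m"
	colorYellow = "\033[33m"
	colorGreen  = "\033[32m"
	colorBold   = "\033[1m"
)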

pkg/inspector/rules_test.go Normal file

@@ -0,0 +1,249 @@
package inspector
import (
"os"
"path/filepath"
"testing"
)
func TestGetDefaultConfig(t *testing.T) {
config := GetDefaultConfig()
if config == nil {
t.Fatal("GetDefaultConfig() returned nil")
}
if config.Version != "1.0" {
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
}
if len(config.Rules) == 0 {
t.Error("GetDefaultConfig() returned no rules")
}
// Check that all expected rules are present
expectedRules := []string{
"primary_key_naming",
"primary_key_datatype",
"primary_key_auto_increment",
"foreign_key_column_naming",
"foreign_key_constraint_naming",
"foreign_key_index",
"table_naming_case",
"column_naming_case",
"table_name_length",
"column_name_length",
"reserved_keywords",
"missing_primary_key",
"orphaned_foreign_key",
"circular_dependency",
}
for _, ruleName := range expectedRules {
if _, exists := config.Rules[ruleName]; !exists {
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
}
}
}
func TestLoadConfig_NonExistentFile(t *testing.T) {
// Try to load a non-existent file
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
if err != nil {
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
}
// Should return default config
if config == nil {
t.Fatal("LoadConfig() returned nil config for non-existent file")
}
if len(config.Rules) == 0 {
t.Error("LoadConfig() returned config with no rules")
}
}
func TestLoadConfig_ValidFile(t *testing.T) {
// Create a temporary config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test-config.yaml")
configContent := `version: "1.0"
rules:
  primary_key_naming:
    enabled: "enforce"
    function: "primary_key_naming"
    pattern: "^pk_"
    message: "Primary keys must start with pk_"
  table_name_length:
    enabled: "warn"
    function: "table_name_length"
    max_length: 50
    message: "Table name too long"
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
config, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() returned error: %v", err)
}
if config.Version != "1.0" {
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
}
if len(config.Rules) != 2 {
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
}
// Check primary_key_naming rule
pkRule, exists := config.Rules["primary_key_naming"]
if !exists {
t.Fatal("LoadConfig() missing primary_key_naming rule")
}
if pkRule.Enabled != "enforce" {
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
}
if pkRule.Pattern != "^pk_" {
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
}
// Check table_name_length rule
lengthRule, exists := config.Rules["table_name_length"]
if !exists {
t.Fatal("LoadConfig() missing table_name_length rule")
}
if lengthRule.MaxLength != 50 {
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
}
}
func TestLoadConfig_InvalidYAML(t *testing.T) {
// Create a temporary invalid config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
invalidContent := `invalid: yaml: content: {[}]`
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
_, err = LoadConfig(configPath)
if err == nil {
t.Error("LoadConfig() with invalid YAML did not return error")
}
}
func TestRuleIsEnabled(t *testing.T) {
tests := []struct {
name string
rule Rule
want bool
}{
{
name: "enforce is enabled",
rule: Rule{Enabled: "enforce"},
want: true,
},
{
name: "warn is enabled",
rule: Rule{Enabled: "warn"},
want: true,
},
{
name: "off is not enabled",
rule: Rule{Enabled: "off"},
want: false,
},
{
name: "empty is not enabled",
rule: Rule{Enabled: ""},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.rule.IsEnabled(); got != tt.want {
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
}
})
}
}
func TestRuleIsEnforced(t *testing.T) {
tests := []struct {
name string
rule Rule
want bool
}{
{
name: "enforce is enforced",
rule: Rule{Enabled: "enforce"},
want: true,
},
{
name: "warn is not enforced",
rule: Rule{Enabled: "warn"},
want: false,
},
{
name: "off is not enforced",
rule: Rule{Enabled: "off"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.rule.IsEnforced(); got != tt.want {
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
}
})
}
}
func TestDefaultConfigRuleSettings(t *testing.T) {
config := GetDefaultConfig()
// Test specific rule settings
pkNamingRule := config.Rules["primary_key_naming"]
if pkNamingRule.Function != "primary_key_naming" {
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
}
if pkNamingRule.Pattern != "^id_" {
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
}
// Test datatype rule
pkDatatypeRule := config.Rules["primary_key_datatype"]
if len(pkDatatypeRule.AllowedTypes) == 0 {
t.Error("primary_key_datatype has no allowed types")
}
// Test length rule
tableLengthRule := config.Rules["table_name_length"]
if tableLengthRule.MaxLength != 64 {
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
}
// Test reserved keywords rule
reservedRule := config.Rules["reserved_keywords"]
if !reservedRule.CheckTables {
t.Error("reserved_keywords.CheckTables should be true")
}
if !reservedRule.CheckColumns {
t.Error("reserved_keywords.CheckColumns should be true")
}
}
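The fields these tests read imply the shape of Rule. A sketch reconstructed from the accesses above; the yaml tags are assumptions inferred from the sample config in TestLoadConfig_ValidFile:

type Rule struct {
	Enabled              string   `yaml:"enabled"` // "enforce", "warn", or "off"
	Function             string   `yaml:"function"`
	Pattern              string   `yaml:"pattern"`
	Message              string   `yaml:"message"`
	Case                 string   `yaml:"case"`
	MaxLength            int      `yaml:"max_length"`
	AllowedTypes         []string `yaml:"allowed_types"`
	RequireAutoIncrement bool     `yaml:"require_auto_increment"`
	RequireIndex         bool     `yaml:"require_index"`
	CheckTables          bool     `yaml:"check_tables"`
	CheckColumns         bool     `yaml:"check_columns"`
}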


@@ -0,0 +1,837 @@
package inspector
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// Helper function to create test database
func createTestDatabase() *models.Database {
return &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigserial",
IsPrimaryKey: true,
AutoIncrement: true,
},
"username": {
Name: "username",
Type: "varchar(50)",
NotNull: true,
IsPrimaryKey: false,
},
"rid_organization": {
Name: "rid_organization",
Type: "bigint",
NotNull: true,
IsPrimaryKey: false,
},
},
Constraints: map[string]*models.Constraint{
"fk_users_organization": {
Name: "fk_users_organization",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_organization"},
ReferencedTable: "organizations",
ReferencedSchema: "public",
ReferencedColumns: []string{"id"},
},
},
Indexes: map[string]*models.Index{
"idx_rid_organization": {
Name: "idx_rid_organization",
Columns: []string{"rid_organization"},
},
},
},
{
Name: "organizations",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigserial",
IsPrimaryKey: true,
AutoIncrement: true,
},
"name": {
Name: "name",
Type: "varchar(100)",
NotNull: true,
IsPrimaryKey: false,
},
},
},
},
},
},
}
}
func TestValidatePrimaryKeyNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern id",
rule: Rule{
Pattern: "^id$",
Message: "Primary key should be 'id'",
},
wantLen: 2,
wantPass: true,
},
{
name: "non-matching pattern id_",
rule: Rule{
Pattern: "^id_",
Message: "Primary key should start with 'id_'",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidatePrimaryKeyDatatype(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "allowed type bigserial",
rule: Rule{
AllowedTypes: []string{"bigserial", "bigint", "int"},
Message: "Primary key should use integer types",
},
wantLen: 2,
wantPass: true,
},
{
name: "disallowed type",
rule: Rule{
AllowedTypes: []string{"uuid"},
Message: "Primary key should use UUID",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
}{
{
name: "require auto increment",
rule: Rule{
RequireAutoIncrement: true,
Message: "Primary key should have auto-increment",
},
wantLen: 0, // No violations - all PKs have auto-increment
},
{
name: "disallow auto increment",
rule: Rule{
RequireAutoIncrement: false,
Message: "Primary key should not have auto-increment",
},
wantLen: 2, // 2 violations - both PKs have auto-increment
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateForeignKeyColumnNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern rid_",
rule: Rule{
Pattern: "^rid_",
Message: "Foreign key columns should start with 'rid_'",
},
wantLen: 1,
wantPass: true,
},
{
name: "non-matching pattern fk_",
rule: Rule{
Pattern: "^fk_",
Message: "Foreign key columns should start with 'fk_'",
},
wantLen: 1,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern fk_",
rule: Rule{
Pattern: "^fk_",
Message: "Foreign key constraints should start with 'fk_'",
},
wantLen: 1,
wantPass: true,
},
{
name: "non-matching pattern FK_",
rule: Rule{
Pattern: "^FK_",
Message: "Foreign key constraints should start with 'FK_'",
},
wantLen: 1,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateForeignKeyIndex(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "require index with index present",
rule: Rule{
RequireIndex: true,
Message: "Foreign key columns should have indexes",
},
wantLen: 1,
wantPass: true,
},
{
name: "no requirement",
rule: Rule{
RequireIndex: false,
Message: "Foreign key index check disabled",
},
wantLen: 0,
wantPass: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateTableNamingCase(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "lowercase snake_case pattern",
rule: Rule{
Pattern: "^[a-z][a-z0-9_]*$",
Case: "lowercase",
Message: "Table names should be lowercase snake_case",
},
wantLen: 2,
wantPass: true,
},
{
name: "uppercase pattern",
rule: Rule{
Pattern: "^[A-Z][A-Z0-9_]*$",
Case: "uppercase",
Message: "Table names should be uppercase",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateTableNamingCase(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateColumnNamingCase(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "lowercase snake_case pattern",
rule: Rule{
Pattern: "^[a-z][a-z0-9_]*$",
Case: "lowercase",
Message: "Column names should be lowercase snake_case",
},
wantLen: 5, // 5 total columns across both tables
wantPass: true,
},
{
name: "camelCase pattern",
rule: Rule{
Pattern: "^[a-z][a-zA-Z0-9]*$",
Case: "camelCase",
Message: "Column names should be camelCase",
},
wantLen: 5,
wantPass: false, // rid_organization has underscore
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateColumnNamingCase(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateTableNameLength(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "max length 64",
rule: Rule{
MaxLength: 64,
Message: "Table name too long",
},
wantLen: 2,
wantPass: true,
},
{
name: "max length 5",
rule: Rule{
MaxLength: 5,
Message: "Table name too long",
},
wantLen: 2,
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateTableNameLength(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateColumnNameLength(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "max length 64",
rule: Rule{
MaxLength: 64,
Message: "Column name too long",
},
wantLen: 5,
wantPass: true,
},
{
name: "max length 5",
rule: Rule{
MaxLength: 5,
Message: "Column name too long",
},
wantLen: 5,
wantPass: false, // Some columns exceed 5 chars
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateColumnNameLength(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateReservedKeywords(t *testing.T) {
// Create a database with reserved keywords
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "user", // "user" is a reserved keyword
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
"select": { // "select" is a reserved keyword
Name: "select",
Type: "varchar(50)",
},
},
},
},
},
},
}
tests := []struct {
name string
rule Rule
wantLen int
checkPasses bool
}{
{
name: "check tables only",
rule: Rule{
CheckTables: true,
CheckColumns: false,
Message: "Reserved keyword used",
},
wantLen: 1, // "user" table
checkPasses: false,
},
{
name: "check columns only",
rule: Rule{
CheckTables: false,
CheckColumns: true,
Message: "Reserved keyword used",
},
wantLen: 2, // "id", "select" columns (id passes, select fails)
checkPasses: false,
},
{
name: "check both",
rule: Rule{
CheckTables: true,
CheckColumns: true,
Message: "Reserved keyword used",
},
wantLen: 3, // "user" table + "id", "select" columns
checkPasses: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateReservedKeywords(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateMissingPrimaryKey(t *testing.T) {
// Create database with and without primary keys
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "with_pk",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
},
},
{
Name: "without_pk",
Columns: map[string]*models.Column{
"name": {
Name: "name",
Type: "varchar(50)",
},
},
},
},
},
},
}
rule := Rule{
Message: "Table missing primary key",
}
results := validateMissingPrimaryKey(db, rule, "test_rule")
if len(results) != 2 {
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
}
// First result should pass (with_pk has PK)
if !results[0].Passed {
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
}
// Second result should fail (without_pk missing PK)
if results[1].Passed {
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
}
}
func TestValidateOrphanedForeignKey(t *testing.T) {
// Create database with orphaned FK
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
},
Constraints: map[string]*models.Constraint{
"fk_nonexistent": {
Name: "fk_nonexistent",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_organization"},
ReferencedTable: "nonexistent_table",
ReferencedSchema: "public",
},
},
},
},
},
},
}
rule := Rule{
Message: "Foreign key references non-existent table",
}
results := validateOrphanedForeignKey(db, rule, "test_rule")
if len(results) != 1 {
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
}
if results[0].Passed {
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
}
}
func TestValidateCircularDependency(t *testing.T) {
// Create database with circular dependency
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "table_a",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
},
Constraints: map[string]*models.Constraint{
"fk_to_b": {
Name: "fk_to_b",
Type: models.ForeignKeyConstraint,
ReferencedTable: "table_b",
ReferencedSchema: "public",
},
},
},
{
Name: "table_b",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
},
Constraints: map[string]*models.Constraint{
"fk_to_a": {
Name: "fk_to_a",
Type: models.ForeignKeyConstraint,
ReferencedTable: "table_a",
ReferencedSchema: "public",
},
},
},
},
},
},
}
rule := Rule{
Message: "Circular dependency detected",
}
results := validateCircularDependency(db, rule, "test_rule")
// Should detect circular dependency in both tables
if len(results) == 0 {
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
}
for _, result := range results {
if result.Passed {
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
}
}
}
func TestNormalizeDataType(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"varchar(50)", "varchar"},
{"decimal(10,2)", "decimal"},
{"int", "int"},
{"BIGINT", "bigint"},
{"VARCHAR(255)", "varchar"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := normalizeDataType(tt.input)
if result != tt.expected {
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
value string
expected bool
}{
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
{"empty slice", []string{}, "foo", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.value)
if result != tt.expected {
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
}
})
}
}
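// Plausible sketches (assumed, not taken from the source) of the two helpers
// the tables above pin down: normalizeDataType strips a parenthesized size or
// precision suffix and lowercases, and contains is case-insensitive per the
// "case insensitive match" case. Both rely on the standard strings package.
func normalizeDataType(t string) string {
	if i := strings.Index(t, "("); i >= 0 {
		t = t[:i]
	}
	return strings.ToLower(t)
}

func contains(slice []string, value string) bool {
	for _, s := range slice {
		if strings.EqualFold(s, value) {
			return true
		}
	}
	return false
}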
func TestHasCycle(t *testing.T) {
tests := []struct {
name string
graph map[string][]string
node string
expected bool
}{
{
name: "simple cycle",
graph: map[string][]string{
"A": {"B"},
"B": {"C"},
"C": {"A"},
},
node: "A",
expected: true,
},
{
name: "no cycle",
graph: map[string][]string{
"A": {"B"},
"B": {"C"},
"C": {},
},
node: "A",
expected: false,
},
{
name: "self cycle",
graph: map[string][]string{
"A": {"A"},
},
node: "A",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
visited := make(map[string]bool)
recStack := make(map[string]bool)
result := hasCycle(tt.node, tt.graph, visited, recStack)
if result != tt.expected {
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
}
})
}
}
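// A sketch of the depth-first cycle check these cases describe (assumed, not
// taken from the source): recStack tracks the nodes on the current DFS path,
// visited prunes re-exploration, and reaching a neighbor already on the path
// means a cycle.
func hasCycle(node string, graph map[string][]string, visited, recStack map[string]bool) bool {
	if recStack[node] {
		return true
	}
	if visited[node] {
		return false
	}
	visited[node] = true
	recStack[node] = true
	for _, next := range graph[node] {
		if hasCycle(next, graph, visited, recStack) {
			return true
		}
	}
	recStack[node] = false
	return false
}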
func TestFormatLocation(t *testing.T) {
tests := []struct {
schema string
table string
column string
expected string
}{
{"public", "users", "id", "public.users.id"},
{"public", "users", "", "public.users"},
{"public", "", "", "public"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := formatLocation(tt.schema, tt.table, tt.column)
if result != tt.expected {
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
tt.schema, tt.table, tt.column, result, tt.expected)
}
})
}
}


@@ -12,14 +12,16 @@ import (
// MergeResult represents the result of a merge operation
type MergeResult struct {
-SchemasAdded int
-TablesAdded int
-ColumnsAdded int
-RelationsAdded int
-DomainsAdded int
-EnumsAdded int
-ViewsAdded int
-SequencesAdded int
+SchemasAdded int
+TablesAdded int
+ColumnsAdded int
+ConstraintsAdded int
+IndexesAdded int
+RelationsAdded int
+DomainsAdded int
+EnumsAdded int
+ViewsAdded int
+SequencesAdded int
}
// MergeOptions contains options for merge operations
@@ -120,8 +122,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
}
if tgtTable, exists := existingTables[tableName]; exists {
-// Table exists, merge its columns
+// Table exists, merge its columns, constraints, and indexes
r.mergeColumns(tgtTable, srcTable)
+r.mergeConstraints(tgtTable, srcTable)
+r.mergeIndexes(tgtTable, srcTable)
} else {
// Table doesn't exist, add it
newTable := cloneTable(srcTable)
@@ -151,6 +155,52 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
}
}
+func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
+// Initialize constraints map if nil
+if table.Constraints == nil {
+table.Constraints = make(map[string]*models.Constraint)
+}
+// Create map of existing constraints
+existingConstraints := make(map[string]*models.Constraint)
+for constName := range table.Constraints {
+existingConstraints[constName] = table.Constraints[constName]
+}
+// Merge constraints
+for constName, srcConst := range srcTable.Constraints {
+if _, exists := existingConstraints[constName]; !exists {
+// Constraint doesn't exist, add it
+newConst := cloneConstraint(srcConst)
+table.Constraints[constName] = newConst
+r.ConstraintsAdded++
+}
+}
+}
+func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
+// Initialize indexes map if nil
+if table.Indexes == nil {
+table.Indexes = make(map[string]*models.Index)
+}
+// Create map of existing indexes
+existingIndexes := make(map[string]*models.Index)
+for idxName := range table.Indexes {
+existingIndexes[idxName] = table.Indexes[idxName]
+}
+// Merge indexes
+for idxName, srcIdx := range srcTable.Indexes {
+if _, exists := existingIndexes[idxName]; !exists {
+// Index doesn't exist, add it
+newIdx := cloneIndex(srcIdx)
+table.Indexes[idxName] = newIdx
+r.IndexesAdded++
+}
+}
+}
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
// Create map of existing views
existingViews := make(map[string]*models.View)
@@ -552,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
fmt.Sprintf("Tables added: %d", result.TablesAdded),
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
fmt.Sprintf("Views added: %d", result.ViewsAdded),
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
@@ -560,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
}
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
+result.ConstraintsAdded + result.IndexesAdded +
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
result.RelationsAdded + result.DomainsAdded
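The mergeConstraints and mergeIndexes hunks call cloneConstraint and cloneIndex, which this diff does not show. A plausible sketch of those helpers, assuming shallow copies with duplicated column slices (the real implementations live elsewhere in the package and may copy more fields):

func cloneConstraint(c *models.Constraint) *models.Constraint {
	dup := *c // copy scalar fields
	dup.Columns = append([]string(nil), c.Columns...)
	dup.ReferencedColumns = append([]string(nil), c.ReferencedColumns...)
	return &dup
}

func cloneIndex(i *models.Index) *models.Index {
	dup := *i
	dup.Columns = append([]string(nil), i.Columns...)
	return &dup
}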

pkg/merge/merge_test.go Normal file

@@ -0,0 +1,617 @@
package merge
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestMergeDatabases_NilInputs(t *testing.T) {
result := MergeDatabases(nil, nil, nil)
if result == nil {
t.Fatal("Expected non-nil result")
}
if result.SchemasAdded != 0 {
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
}
}
func TestMergeDatabases_NewSchema(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{Name: "auth"},
},
}
result := MergeDatabases(target, source, nil)
if result.SchemasAdded != 1 {
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
}
if len(target.Schemas) != 2 {
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
}
}
func TestMergeDatabases_ExistingSchema(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
result := MergeDatabases(target, source, nil)
if result.SchemasAdded != 0 {
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
}
if len(target.Schemas) != 1 {
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
}
}
func TestMergeTables_NewTable(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "posts",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.TablesAdded != 1 {
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
}
if len(target.Schemas[0].Tables) != 2 {
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
}
}
func TestMergeColumns_NewColumn(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "int"},
},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar"},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ColumnsAdded != 1 {
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
}
if len(target.Schemas[0].Tables[0].Columns) != 2 {
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
}
}
func TestMergeConstraints_NewConstraint(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ConstraintsAdded != 1 {
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
}
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
}
}
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: nil, // Nil map
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ConstraintsAdded != 1 {
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
}
if target.Schemas[0].Tables[0].Constraints == nil {
t.Error("Expected constraints map to be initialized")
}
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
}
}
func TestMergeIndexes_NewIndex(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
}
}
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: nil, // Nil map
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
if target.Schemas[0].Tables[0].Indexes == nil {
t.Error("Expected indexes map to be initialized")
}
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
}
}
func TestMergeOptions_SkipTableNames(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "migrations",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
opts := &MergeOptions{
SkipTableNames: map[string]bool{
"migrations": true,
},
}
result := MergeDatabases(target, source, opts)
if result.TablesAdded != 0 {
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
}
if len(target.Schemas[0].Tables) != 1 {
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
}
}
func TestMergeViews_NewView(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Views: []*models.View{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Views: []*models.View{
{
Name: "user_summary",
Schema: "public",
Definition: "SELECT * FROM users",
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ViewsAdded != 1 {
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
}
if len(target.Schemas[0].Views) != 1 {
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
}
}
func TestMergeEnums_NewEnum(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Enums: []*models.Enum{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Enums: []*models.Enum{
{
Name: "user_role",
Schema: "public",
Values: []string{"admin", "user"},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.EnumsAdded != 1 {
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
}
if len(target.Schemas[0].Enums) != 1 {
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
}
}
func TestMergeDomains_NewDomain(t *testing.T) {
target := &models.Database{
Domains: []*models.Domain{},
}
source := &models.Database{
Domains: []*models.Domain{
{
Name: "auth",
Description: "Authentication domain",
},
},
}
result := MergeDatabases(target, source, nil)
if result.DomainsAdded != 1 {
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
}
if len(target.Domains) != 1 {
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
}
}
func TestMergeRelations_NewRelation(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Relations: []*models.Relationship{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Relations: []*models.Relationship{
{
Name: "fk_posts_user",
Type: models.OneToMany,
FromTable: "posts",
FromColumns: []string{"user_id"},
ToTable: "users",
ToColumns: []string{"id"},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.RelationsAdded != 1 {
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
}
if len(target.Schemas[0].Relations) != 1 {
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
}
}
func TestGetMergeSummary(t *testing.T) {
result := &MergeResult{
SchemasAdded: 1,
TablesAdded: 2,
ColumnsAdded: 5,
ConstraintsAdded: 3,
IndexesAdded: 2,
ViewsAdded: 1,
}
summary := GetMergeSummary(result)
if summary == "" {
t.Error("Expected non-empty summary")
}
if len(summary) < 50 {
t.Errorf("Summary seems too short: %s", summary)
}
}
func TestGetMergeSummary_Nil(t *testing.T) {
summary := GetMergeSummary(nil)
if summary == "" {
t.Error("Expected non-empty summary for nil result")
}
}
func TestComplexMerge(t *testing.T) {
// Target with existing structure
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "int"},
},
Constraints: map[string]*models.Constraint{},
Indexes: map[string]*models.Index{},
},
},
},
},
}
// Source with new columns, constraints, and indexes
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar"},
"guid": {Name: "guid", Type: "uuid"},
},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
"ukey_users_guid": {
Type: models.UniqueConstraint,
Columns: []string{"guid"},
Name: "ukey_users_guid",
},
},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
// Verify counts
if result.ColumnsAdded != 2 {
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
}
if result.ConstraintsAdded != 2 {
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
}
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
// Verify target has merged data
table := target.Schemas[0].Tables[0]
if len(table.Columns) != 3 {
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
}
if len(table.Constraints) != 2 {
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
}
if len(table.Indexes) != 1 {
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
}
// Verify specific constraint
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
t.Error("Expected ukey_users_guid constraint to exist")
}
}
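Taken together, these tests cover the whole merge surface: schemas, tables, columns, constraints, indexes, views, enums, domains, and relations. As a rough end-to-end sketch of the same API (`loadTarget`/`loadSource` are hypothetical stand-ins for one of the readers; assumes an `fmt` import):

```go
// Merge a source schema snapshot into a target, skipping bookkeeping tables.
// MergeDatabases mutates target in place and returns the counters that
// GetMergeSummary formats, exactly as the tests above assert.
opts := &MergeOptions{
	SkipTableNames: map[string]bool{"migrations": true},
}
target := loadTarget() // hypothetical: *models.Database from any reader
source := loadSource() // hypothetical: *models.Database from any reader
result := MergeDatabases(target, source, opts)
fmt.Println(GetMergeSummary(result))
```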


@@ -4,31 +4,31 @@ import "strings"
var GoToStdTypes = map[string]string{
"bool": "boolean",
"int64": "integer",
"int64": "bigint",
"int": "integer",
"int8": "integer",
"int16": "integer",
"int8": "smallint",
"int16": "smallint",
"int32": "integer",
"uint": "integer",
"uint8": "integer",
"uint16": "integer",
"uint8": "smallint",
"uint16": "smallint",
"uint32": "integer",
"uint64": "integer",
"uintptr": "integer",
"znullint64": "integer",
"uint64": "bigint",
"uintptr": "bigint",
"znullint64": "bigint",
"znullint32": "integer",
"znullbyte": "integer",
"znullbyte": "smallint",
"float64": "double",
"float32": "double",
"complex64": "double",
"complex128": "double",
"customfloat64": "double",
"string": "string",
"Pointer": "integer",
"string": "text",
"Pointer": "bigint",
"[]byte": "blob",
"customdate": "string",
"customtime": "string",
"customtimestamp": "string",
"customdate": "date",
"customtime": "time",
"customtimestamp": "timestamp",
"sqlfloat64": "double",
"sqlfloat16": "double",
"sqluuid": "uuid",
@@ -36,9 +36,9 @@ var GoToStdTypes = map[string]string{
"sqljson": "json",
"sqlint64": "bigint",
"sqlint32": "integer",
"sqlint16": "integer",
"sqlint16": "smallint",
"sqlbool": "boolean",
"sqlstring": "string",
"sqlstring": "text",
"nullablejsonb": "jsonb",
"nullablejson": "json",
"nullableuuid": "uuid",
@@ -67,7 +67,7 @@ var GoToPGSQLTypes = map[string]string{
"float32": "real",
"complex64": "double precision",
"complex128": "double precision",
"customfloat64": "double precisio",
"customfloat64": "double precision",
"string": "text",
"Pointer": "bigint",
"[]byte": "bytea",
@@ -81,9 +81,9 @@ var GoToPGSQLTypes = map[string]string{
"sqljson": "json",
"sqlint64": "bigint",
"sqlint32": "integer",
"sqlint16": "integer",
"sqlint16": "smallint",
"sqlbool": "boolean",
"sqlstring": "string",
"sqlstring": "text",
"nullablejsonb": "jsonb",
"nullablejson": "json",
"nullableuuid": "uuid",

pkg/pgsql/datatypes_test.go Normal file

@@ -0,0 +1,339 @@
package pgsql
import (
"testing"
)
func TestValidSQLType(t *testing.T) {
tests := []struct {
name string
sqltype string
want bool
}{
// PostgreSQL types
{"Valid PGSQL bigint", "bigint", true},
{"Valid PGSQL integer", "integer", true},
{"Valid PGSQL text", "text", true},
{"Valid PGSQL boolean", "boolean", true},
{"Valid PGSQL double precision", "double precision", true},
{"Valid PGSQL bytea", "bytea", true},
{"Valid PGSQL uuid", "uuid", true},
{"Valid PGSQL jsonb", "jsonb", true},
{"Valid PGSQL json", "json", true},
{"Valid PGSQL timestamp", "timestamp", true},
{"Valid PGSQL date", "date", true},
{"Valid PGSQL time", "time", true},
{"Valid PGSQL citext", "citext", true},
// Standard types
{"Valid std double", "double", true},
{"Valid std blob", "blob", true},
// Case insensitive
{"Case insensitive BIGINT", "BIGINT", true},
{"Case insensitive TeXt", "TeXt", true},
{"Case insensitive BoOlEaN", "BoOlEaN", true},
// Invalid types
{"Invalid type", "invalidtype", false},
{"Invalid type varchar", "varchar", false},
{"Empty string", "", false},
{"Random string", "foobar", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ValidSQLType(tt.sqltype)
if got != tt.want {
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
}
})
}
}
func TestGetSQLType(t *testing.T) {
tests := []struct {
name string
anytype string
want string
}{
// Go types to PostgreSQL types
{"Go bool to boolean", "bool", "boolean"},
{"Go int64 to bigint", "int64", "bigint"},
{"Go int to integer", "int", "integer"},
{"Go string to text", "string", "text"},
{"Go float64 to double precision", "float64", "double precision"},
{"Go float32 to real", "float32", "real"},
{"Go []byte to bytea", "[]byte", "bytea"},
// SQL types remain SQL types
{"SQL bigint", "bigint", "bigint"},
{"SQL integer", "integer", "integer"},
{"SQL text", "text", "text"},
{"SQL boolean", "boolean", "boolean"},
{"SQL uuid", "uuid", "uuid"},
{"SQL jsonb", "jsonb", "jsonb"},
// Case insensitive Go types
{"Case insensitive BOOL", "BOOL", "boolean"},
{"Case insensitive InT64", "InT64", "bigint"},
{"Case insensitive STRING", "STRING", "text"},
// Case insensitive SQL types
{"Case insensitive BIGINT", "BIGINT", "bigint"},
{"Case insensitive TEXT", "TEXT", "text"},
// Custom types
{"Custom sqluuid", "sqluuid", "uuid"},
{"Custom sqljsonb", "sqljsonb", "jsonb"},
{"Custom sqlint64", "sqlint64", "bigint"},
// Unknown types default to text
{"Unknown type varchar", "varchar", "text"},
{"Unknown type foobar", "foobar", "text"},
{"Empty string", "", "text"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetSQLType(tt.anytype)
if got != tt.want {
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
}
})
}
}
func TestConvertSQLType(t *testing.T) {
tests := []struct {
name string
anytype string
want string
}{
// Go types to PostgreSQL types
{"Go bool to boolean", "bool", "boolean"},
{"Go int64 to bigint", "int64", "bigint"},
{"Go int to integer", "int", "integer"},
{"Go string to text", "string", "text"},
{"Go float64 to double precision", "float64", "double precision"},
{"Go float32 to real", "float32", "real"},
{"Go []byte to bytea", "[]byte", "bytea"},
// SQL types remain SQL types
{"SQL bigint", "bigint", "bigint"},
{"SQL integer", "integer", "integer"},
{"SQL text", "text", "text"},
{"SQL boolean", "boolean", "boolean"},
// Case insensitive
{"Case insensitive BOOL", "BOOL", "boolean"},
{"Case insensitive InT64", "InT64", "bigint"},
// Unknown types remain unchanged (difference from GetSQLType)
{"Unknown type varchar", "varchar", "varchar"},
{"Unknown type foobar", "foobar", "foobar"},
{"Empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ConvertSQLType(tt.anytype)
if got != tt.want {
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
}
})
}
}
func TestIsGoType(t *testing.T) {
tests := []struct {
name string
typeName string
want bool
}{
// Go basic types
{"Go bool", "bool", true},
{"Go int64", "int64", true},
{"Go int", "int", true},
{"Go int32", "int32", true},
{"Go int16", "int16", true},
{"Go int8", "int8", true},
{"Go uint", "uint", true},
{"Go uint64", "uint64", true},
{"Go uint32", "uint32", true},
{"Go uint16", "uint16", true},
{"Go uint8", "uint8", true},
{"Go float64", "float64", true},
{"Go float32", "float32", true},
{"Go string", "string", true},
{"Go []byte", "[]byte", true},
// Go custom types
{"Go complex64", "complex64", true},
{"Go complex128", "complex128", true},
{"Go uintptr", "uintptr", true},
{"Go Pointer", "Pointer", true},
// Custom SQL types
{"Custom sqluuid", "sqluuid", true},
{"Custom sqljsonb", "sqljsonb", true},
{"Custom sqlint64", "sqlint64", true},
{"Custom customdate", "customdate", true},
{"Custom customtime", "customtime", true},
// Case insensitive
{"Case insensitive BOOL", "BOOL", true},
{"Case insensitive InT64", "InT64", true},
{"Case insensitive STRING", "STRING", true},
// SQL types (not Go types)
{"SQL bigint", "bigint", false},
{"SQL integer", "integer", false},
{"SQL text", "text", false},
{"SQL boolean", "boolean", false},
// Invalid types
{"Invalid type", "invalidtype", false},
{"Empty string", "", false},
{"Random string", "foobar", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsGoType(tt.typeName)
if got != tt.want {
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
}
})
}
}
func TestGetStdTypeFromGo(t *testing.T) {
tests := []struct {
name string
typeName string
want string
}{
// Go types to standard SQL types
{"Go bool to boolean", "bool", "boolean"},
{"Go int64 to bigint", "int64", "bigint"},
{"Go int to integer", "int", "integer"},
{"Go string to text", "string", "text"},
{"Go float64 to double", "float64", "double"},
{"Go float32 to double", "float32", "double"},
{"Go []byte to blob", "[]byte", "blob"},
{"Go int32 to integer", "int32", "integer"},
{"Go int16 to smallint", "int16", "smallint"},
// Custom types
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
{"Custom customdate to date", "customdate", "date"},
// Case insensitive
{"Case insensitive BOOL", "BOOL", "boolean"},
{"Case insensitive InT64", "InT64", "bigint"},
{"Case insensitive STRING", "STRING", "text"},
// Non-Go types remain unchanged
{"SQL bigint unchanged", "bigint", "bigint"},
{"SQL integer unchanged", "integer", "integer"},
{"Invalid type unchanged", "invalidtype", "invalidtype"},
{"Empty string unchanged", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetStdTypeFromGo(tt.typeName)
if got != tt.want {
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
}
})
}
}
func TestGoToStdTypesMap(t *testing.T) {
// Test that the map contains expected entries
expectedMappings := map[string]string{
"bool": "boolean",
"int64": "bigint",
"int": "integer",
"string": "text",
"float64": "double",
"[]byte": "blob",
}
for goType, expectedStd := range expectedMappings {
if stdType, ok := GoToStdTypes[goType]; !ok {
t.Errorf("GoToStdTypes missing entry for %q", goType)
} else if stdType != expectedStd {
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
}
}
// Test that the map is not empty
if len(GoToStdTypes) == 0 {
t.Error("GoToStdTypes map is empty")
}
}
func TestGoToPGSQLTypesMap(t *testing.T) {
// Test that the map contains expected entries
expectedMappings := map[string]string{
"bool": "boolean",
"int64": "bigint",
"int": "integer",
"string": "text",
"float64": "double precision",
"float32": "real",
"[]byte": "bytea",
}
for goType, expectedPG := range expectedMappings {
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
} else if pgType != expectedPG {
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
}
}
// Test that the map is not empty
if len(GoToPGSQLTypes) == 0 {
t.Error("GoToPGSQLTypes map is empty")
}
}
func TestTypeConversionConsistency(t *testing.T) {
// Test that GetSQLType and ConvertSQLType are consistent for known types
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
for _, goType := range knownGoTypes {
getSQLResult := GetSQLType(goType)
convertResult := ConvertSQLType(goType)
if getSQLResult != convertResult {
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
goType, getSQLResult, convertResult)
}
}
}
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
// Test that GetSQLType returns "text" for unknown types
// while ConvertSQLType returns the input unchanged
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
for _, unknown := range unknownTypes {
getSQLResult := GetSQLType(unknown)
convertResult := ConvertSQLType(unknown)
if getSQLResult != "text" {
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
}
if convertResult != unknown {
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
}
}
}

pkg/pgsql/keywords_test.go Normal file

@@ -0,0 +1,136 @@
package pgsql
import (
"testing"
)
func TestGetPostgresKeywords(t *testing.T) {
keywords := GetPostgresKeywords()
// Test that keywords are returned
if len(keywords) == 0 {
t.Fatal("Expected non-empty list of keywords")
}
// Test that we get all keywords from the map
expectedCount := len(postgresKeywords)
if len(keywords) != expectedCount {
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
}
// Test that all returned keywords exist in the map
for _, keyword := range keywords {
if !postgresKeywords[keyword] {
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
}
}
// Test that no duplicate keywords are returned
seen := make(map[string]bool)
for _, keyword := range keywords {
if seen[keyword] {
t.Errorf("Duplicate keyword found: %q", keyword)
}
seen[keyword] = true
}
}
func TestPostgresKeywordsMap(t *testing.T) {
tests := []struct {
name string
keyword string
want bool
}{
{"SELECT keyword", "select", true},
{"FROM keyword", "from", true},
{"WHERE keyword", "where", true},
{"TABLE keyword", "table", true},
{"PRIMARY keyword", "primary", true},
{"FOREIGN keyword", "foreign", true},
{"CREATE keyword", "create", true},
{"DROP keyword", "drop", true},
{"ALTER keyword", "alter", true},
{"INDEX keyword", "index", true},
{"NOT keyword", "not", true},
{"NULL keyword", "null", true},
{"TRUE keyword", "true", true},
{"FALSE keyword", "false", true},
{"Non-keyword lowercase", "notakeyword", false},
{"Non-keyword uppercase", "NOTAKEYWORD", false},
{"Empty string", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := postgresKeywords[tt.keyword]
if got != tt.want {
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
}
})
}
}
func TestPostgresKeywordsMapContent(t *testing.T) {
// Test that the map contains expected common keywords
commonKeywords := []string{
"select", "insert", "update", "delete", "create", "drop", "alter",
"table", "index", "view", "schema", "function", "procedure",
"primary", "foreign", "key", "constraint", "unique", "check",
"null", "not", "and", "or", "like", "in", "between",
"join", "inner", "left", "right", "cross", "full", "outer",
"where", "having", "group", "order", "limit", "offset",
"union", "intersect", "except",
"begin", "commit", "rollback", "transaction",
}
for _, keyword := range commonKeywords {
if !postgresKeywords[keyword] {
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
}
}
}
func TestPostgresKeywordsMapSize(t *testing.T) {
// PostgreSQL recognizes a substantial list of key words (reserved and non-reserved).
// This test ensures the map has a reasonable number of entries.
minExpectedKeywords := 200 // the PostgreSQL 13+ key word appendix lists 400+ entries
if len(postgresKeywords) < minExpectedKeywords {
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
minExpectedKeywords, len(postgresKeywords))
}
}
func TestGetPostgresKeywordsConsistency(t *testing.T) {
// Test that calling GetPostgresKeywords multiple times returns consistent results
keywords1 := GetPostgresKeywords()
keywords2 := GetPostgresKeywords()
if len(keywords1) != len(keywords2) {
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
len(keywords1), len(keywords2))
}
// Create a map from both results to compare
map1 := make(map[string]bool)
map2 := make(map[string]bool)
for _, k := range keywords1 {
map1[k] = true
}
for _, k := range keywords2 {
map2[k] = true
}
// Check that both contain the same keywords
for k := range map1 {
if !map2[k] {
t.Errorf("Keyword %q present in first call but not in second", k)
}
}
for k := range map2 {
if !map1[k] {
t.Errorf("Keyword %q present in second call but not in first", k)
}
}
}
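A natural consumer of this keyword map, sketched here as a hypothetical helper (not part of the package; assumes a `strings` import), is identifier quoting in generated SQL:

```go
// quoteIfKeyword double-quotes an identifier that collides with a
// PostgreSQL keyword. Hypothetical sketch, not package API.
func quoteIfKeyword(ident string) string {
	if postgresKeywords[strings.ToLower(ident)] {
		return `"` + ident + `"`
	}
	return ident
}

// quoteIfKeyword("select") -> "\"select\"" (keyword, quoted)
// quoteIfKeyword("email")  -> "email"      (left as-is)
```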


@@ -128,6 +128,46 @@ func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
return db, nil
}
// splitIdentifier splits a dotted identifier while respecting quotes
// Handles cases like: "schema.with.dots"."table"."column"
func splitIdentifier(s string) []string {
var parts []string
var current strings.Builder
inQuote := false
quoteChar := byte(0)
for i := 0; i < len(s); i++ {
ch := s[i]
if !inQuote {
switch ch {
case '"', '\'':
inQuote = true
quoteChar = ch
current.WriteByte(ch)
case '.':
if current.Len() > 0 {
parts = append(parts, current.String())
current.Reset()
}
default:
current.WriteByte(ch)
}
} else {
current.WriteByte(ch)
if ch == quoteChar {
inQuote = false
}
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
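As a concrete example of the quote handling described above (output shown as a comment; the quotes are preserved so stripQuotes below can remove them per part):

```go
parts := splitIdentifier(`"schema.with.dots"."table"."column"`)
// parts == []string{`"schema.with.dots"`, `"table"`, `"column"`}
// The dots inside the quoted schema name do not split the identifier.
```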
// stripQuotes removes surrounding quotes and comments from an identifier
func stripQuotes(s string) string {
s = strings.TrimSpace(s)
@@ -409,7 +449,9 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
// Parse Table definition
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
tableName := matches[1]
parts := strings.Split(tableName, ".")
// Strip comments/notes before parsing to avoid dots in notes
tableName = strings.TrimSpace(regexp.MustCompile(`\s*\[.*?\]\s*`).ReplaceAllString(tableName, ""))
parts := splitIdentifier(tableName)
if len(parts) == 2 {
currentSchema = stripQuotes(parts[0])
@@ -561,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
column.Default = strings.Trim(defaultVal, "'\"")
} else if attr == "unique" {
// Create a unique constraint
+// Clean table name by removing leading underscores to avoid double underscores
+cleanTableName := strings.TrimLeft(tableName, "_")
uniqueConstraint := models.InitConstraint(
-fmt.Sprintf("uq_%s", columnName),
+fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
models.UniqueConstraint,
)
uniqueConstraint.Schema = schemaName
@@ -587,10 +631,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
refOp := strings.TrimSpace(refStr)
var isReverse bool
if strings.HasPrefix(refOp, "<") {
-isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse)
-} else if strings.HasPrefix(refOp, ">") {
-isReverse = !column.IsPrimaryKey // > on FK means reverse
+// < means "is referenced by" - only makes sense on PK columns
+isReverse = column.IsPrimaryKey
}
+// > means "references" - always a forward FK, never reverse
constraint = r.parseRef(refStr)
if constraint != nil {
@@ -610,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
constraint.Table = tableName
constraint.Columns = []string{columnName}
}
-// Generate short constraint name based on the column
-constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
+// Generate constraint name based on table and columns
+constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
}
}
}
@@ -695,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Generate name if not provided
if index.Name == "" {
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
prefix := "idx"
if index.Unique {
prefix = "uidx"
}
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
}
return index
@@ -755,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
return nil
}
-// Generate short constraint name based on the source column
-constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
-if len(fromColumns) > 0 {
-constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
+// Generate constraint name based on table and columns
+constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
+if len(fromColumns) == 0 {
+constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
}
constraint := models.InitConstraint(
@@ -814,7 +862,7 @@ func (r *Reader) parseTableRef(ref string) (schema, table string, columns []stri
}
// Parse schema, table, and optionally column
parts := strings.Split(strings.TrimSpace(ref), ".")
parts := splitIdentifier(strings.TrimSpace(ref))
if len(parts) == 3 {
// Format: "schema"."table"."column"
schema = stripQuotes(parts[0])


@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
}
}
func TestConstraintNaming(t *testing.T) {
// Test that constraints are named with proper prefixes
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
// Find users table
var usersTable *models.Table
var postsTable *models.Table
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
if table.Name == "users" {
usersTable = table
} else if table.Name == "posts" {
postsTable = table
}
}
}
if usersTable == nil {
t.Fatal("Users table not found")
}
if postsTable == nil {
t.Fatal("Posts table not found")
}
// Test unique constraint naming: ukey_table_column
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
t.Error("Expected unique constraint 'ukey_users_email' not found")
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
}
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
}
// Test foreign key naming: fk_table_column
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
t.Error("Expected foreign key 'fk_posts_user_id' not found")
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
}
// Test unique index naming: uidx_table_columns
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
t.Error("Expected unique index 'uidx_posts_slug' not found")
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
}
// Test regular index naming: idx_table_columns
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
t.Error("Expected index 'idx_posts_user_id_published' not found")
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
}
}
func getKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
func TestHasCommentedRefs(t *testing.T) {
// Test with the actual multifile test fixtures
tests := []struct {


@@ -329,10 +329,10 @@ func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint)
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
relationship := models.InitRelationship(relationshipName, models.OneToMany)
-relationship.FromTable = fk.ReferencedTable
-relationship.FromSchema = fk.ReferencedSchema
-relationship.ToTable = table.Name
-relationship.ToSchema = table.Schema
+relationship.FromTable = table.Name
+relationship.FromSchema = table.Schema
+relationship.ToTable = fk.ReferencedTable
+relationship.ToSchema = fk.ReferencedSchema
relationship.ForeignKey = fk.Name
// Store constraint actions in properties


@@ -328,12 +328,12 @@ func TestDeriveRelationship(t *testing.T) {
t.Errorf("Expected relationship type %s, got %s", models.OneToMany, rel.Type)
}
if rel.FromTable != "users" {
t.Errorf("Expected FromTable 'users', got '%s'", rel.FromTable)
if rel.FromTable != "orders" {
t.Errorf("Expected FromTable 'orders', got '%s'", rel.FromTable)
}
if rel.ToTable != "orders" {
t.Errorf("Expected ToTable 'orders', got '%s'", rel.ToTable)
if rel.ToTable != "users" {
t.Errorf("Expected ToTable 'users', got '%s'", rel.ToTable)
}
if rel.ForeignKey != "fk_orders_user_id" {


@@ -93,6 +93,7 @@ fmt.Printf("Found %d scripts\n", len(schema.Scripts))
## Features
- **Recursive Directory Scanning**: Automatically scans all subdirectories
- **Symlink Skipping**: Symbolic links are automatically skipped (prevents loops and duplicates)
- **Multiple Extensions**: Supports both `.sql` and `.pgsql` files
- **Flexible Naming**: Extract metadata from filename patterns
- **Error Handling**: Validates directory existence and file accessibility
@@ -153,8 +154,9 @@ go test ./pkg/readers/sqldir/
```
Tests include:
-- Valid file parsing
+- Valid file parsing (underscore and hyphen formats)
- Recursive directory scanning
- Symlink skipping
- Invalid filename handling
- Empty directory handling
- Error conditions


@@ -107,11 +107,20 @@ func (r *Reader) readScripts() ([]*models.Script, error) {
return err
}
-// Skip directories
+// Don't process directories as files (WalkDir still descends into them recursively)
if d.IsDir() {
return nil
}
// Skip symlinks
info, err := d.Info()
if err != nil {
return err
}
if info.Mode()&os.ModeSymlink != 0 {
return nil
}
// Get filename
filename := d.Name()
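Worth noting: `filepath.WalkDir` reports the link entry itself rather than following it, and `fs.DirEntry` already carries the mode-type bits, so an equivalent check could avoid the extra `Info()` call (a possible simplification, not what the code above does):

```go
// d.Type() returns the type bits without a full stat on most platforms.
if d.Type()&os.ModeSymlink != 0 {
	return nil
}
```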


@@ -373,3 +373,65 @@ func TestReader_MixedFormat(t *testing.T) {
}
}
}
func TestReader_SkipSymlinks(t *testing.T) {
// Create temporary test directory
tempDir, err := os.MkdirTemp("", "sqldir-test-symlink-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create a real SQL file
realFile := filepath.Join(tempDir, "1_001_real_file.sql")
if err := os.WriteFile(realFile, []byte("SELECT 1;"), 0644); err != nil {
t.Fatalf("Failed to create real file: %v", err)
}
// Create another file to link to
targetFile := filepath.Join(tempDir, "2_001_target.sql")
if err := os.WriteFile(targetFile, []byte("SELECT 2;"), 0644); err != nil {
t.Fatalf("Failed to create target file: %v", err)
}
// Create a symlink to the target file (this should be skipped)
symlinkFile := filepath.Join(tempDir, "3_001_symlink.sql")
if err := os.Symlink(targetFile, symlinkFile); err != nil {
// Skip test on systems that don't support symlinks (e.g., Windows without admin)
t.Skipf("Symlink creation not supported: %v", err)
}
// Create reader
reader := NewReader(&readers.ReaderOptions{
FilePath: tempDir,
})
// Read database
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase failed: %v", err)
}
schema := db.Schemas[0]
// Should only have 2 scripts (real_file and target), symlink should be skipped
if len(schema.Scripts) != 2 {
t.Errorf("Expected 2 scripts (symlink should be skipped), got %d", len(schema.Scripts))
}
// Verify the scripts are the real files, not the symlink
scriptNames := make(map[string]bool)
for _, script := range schema.Scripts {
scriptNames[script.Name] = true
}
if !scriptNames["real_file"] {
t.Error("Expected 'real_file' script to be present")
}
if !scriptNames["target"] {
t.Error("Expected 'target' script to be present")
}
if scriptNames["symlink"] {
t.Error("Symlink script should have been skipped but was found")
}
}


@@ -0,0 +1,490 @@
package reflectutil
import (
"reflect"
"testing"
)
type testStruct struct {
Name string
Age int
Active bool
Nested *nestedStruct
Private string
}
type nestedStruct struct {
Value string
Count int
}
func TestDeref(t *testing.T) {
tests := []struct {
name string
input interface{}
wantValid bool
wantKind reflect.Kind
}{
{
name: "non-pointer int",
input: 42,
wantValid: true,
wantKind: reflect.Int,
},
{
name: "single pointer",
input: ptrInt(42),
wantValid: true,
wantKind: reflect.Int,
},
{
name: "double pointer",
input: ptrPtr(ptrInt(42)),
wantValid: true,
wantKind: reflect.Int,
},
{
name: "nil pointer",
input: (*int)(nil),
wantValid: false,
wantKind: reflect.Ptr,
},
{
name: "string",
input: "test",
wantValid: true,
wantKind: reflect.String,
},
{
name: "struct",
input: testStruct{Name: "test"},
wantValid: true,
wantKind: reflect.Struct,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := reflect.ValueOf(tt.input)
got, valid := Deref(v)
if valid != tt.wantValid {
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
}
if got.Kind() != tt.wantKind {
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
}
})
}
}
func TestDerefInterface(t *testing.T) {
i := 42
pi := &i
ppi := &pi
tests := []struct {
name string
input interface{}
wantKind reflect.Kind
}{
{"int", 42, reflect.Int},
{"pointer to int", &i, reflect.Int},
{"double pointer to int", ppi, reflect.Int},
{"string", "test", reflect.String},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DerefInterface(tt.input)
if got.Kind() != tt.wantKind {
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
}
})
}
}
func TestGetFieldValue(t *testing.T) {
ts := testStruct{
Name: "John",
Age: 30,
Active: true,
Nested: &nestedStruct{Value: "nested", Count: 5},
}
tests := []struct {
name string
item interface{}
field string
want interface{}
}{
{"struct field Name", ts, "Name", "John"},
{"struct field Age", ts, "Age", 30},
{"struct field Active", ts, "Active", true},
{"struct non-existent field", ts, "NonExistent", nil},
{"pointer to struct", &ts, "Name", "John"},
{"map string key", map[string]string{"key": "value"}, "key", "value"},
{"map int key", map[string]int{"count": 42}, "count", 42},
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
{"nil pointer", (*testStruct)(nil), "Name", nil},
{"non-struct non-map", 42, "field", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetFieldValue(tt.item, tt.field)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsSliceOrArray(t *testing.T) {
arr := [3]int{1, 2, 3}
tests := []struct {
name string
input interface{}
want bool
}{
{"slice", []int{1, 2, 3}, true},
{"array", arr, true},
{"pointer to slice", &[]int{1, 2, 3}, true},
{"string", "test", false},
{"int", 42, false},
{"map", map[string]int{}, false},
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
{"nil pointer", (*[]int)(nil), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsSliceOrArray(tt.input)
if got != tt.want {
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsMap(t *testing.T) {
tests := []struct {
name string
input interface{}
want bool
}{
{"map[string]int", map[string]int{"a": 1}, true},
{"map[int]string", map[int]string{1: "a"}, true},
{"pointer to map", &map[string]int{"a": 1}, true},
{"slice", []int{1, 2, 3}, false},
{"string", "test", false},
{"int", 42, false},
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
{"nil pointer", (*map[string]int)(nil), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsMap(tt.input)
if got != tt.want {
t.Errorf("IsMap() = %v, want %v", got, tt.want)
}
})
}
}
func TestSliceLen(t *testing.T) {
arr := [3]int{1, 2, 3}
tests := []struct {
name string
input interface{}
want int
}{
{"slice length 3", []int{1, 2, 3}, 3},
{"empty slice", []int{}, 0},
{"array length 3", arr, 3},
{"pointer to slice", &[]int{1, 2, 3}, 3},
{"not a slice", "test", 0},
{"int", 42, 0},
{"nil slice", ([]int)(nil), 0},
{"nil pointer", (*[]int)(nil), 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SliceLen(tt.input)
if got != tt.want {
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
}
})
}
}
func TestMapLen(t *testing.T) {
tests := []struct {
name string
input interface{}
want int
}{
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
{"empty map", map[string]int{}, 0},
{"pointer to map", &map[string]int{"a": 1}, 1},
{"not a map", []int{1, 2, 3}, 0},
{"string", "test", 0},
{"nil map", (map[string]int)(nil), 0},
{"nil pointer", (*map[string]int)(nil), 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapLen(tt.input)
if got != tt.want {
t.Errorf("MapLen() = %v, want %v", got, tt.want)
}
})
}
}
func TestSliceToInterfaces(t *testing.T) {
tests := []struct {
name string
input interface{}
want []interface{}
}{
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
{"empty slice", []int{}, []interface{}{}},
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
{"not a slice", "test", []interface{}{}},
{"nil slice", ([]int)(nil), []interface{}{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SliceToInterfaces(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
}
})
}
}
func TestMapKeys(t *testing.T) {
tests := []struct {
name string
input interface{}
want []interface{}
}{
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
{"empty map", map[string]int{}, []interface{}{}},
{"not a map", []int{1, 2, 3}, []interface{}{}},
{"nil map", (map[string]int)(nil), []interface{}{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapKeys(tt.input)
if len(got) != len(tt.want) {
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
}
// For maps, order is not guaranteed, so just check length
})
}
}
func TestMapValues(t *testing.T) {
tests := []struct {
name string
input interface{}
want int // length of values
}{
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
{"empty map", map[string]int{}, 0},
{"not a map", []int{1, 2, 3}, 0},
{"nil map", (map[string]int)(nil), 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapValues(tt.input)
if len(got) != tt.want {
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
}
})
}
}
func TestMapGet(t *testing.T) {
m := map[string]int{"a": 1, "b": 2}
tests := []struct {
name string
input interface{}
key interface{}
want interface{}
}{
{"existing key", m, "a", 1},
{"existing key b", m, "b", 2},
{"non-existing key", m, "c", nil},
{"pointer to map", &m, "a", 1},
{"not a map", []int{1, 2}, 0, nil},
{"nil map", (map[string]int)(nil), "a", nil},
{"nil pointer", (*map[string]int)(nil), "a", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapGet(tt.input, tt.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("MapGet() = %v, want %v", got, tt.want)
}
})
}
}
func TestSliceIndex(t *testing.T) {
s := []int{10, 20, 30}
tests := []struct {
name string
slice interface{}
index int
want interface{}
}{
{"index 0", s, 0, 10},
{"index 1", s, 1, 20},
{"index 2", s, 2, 30},
{"negative index", s, -1, nil},
{"out of bounds", s, 5, nil},
{"pointer to slice", &s, 1, 20},
{"not a slice", "test", 0, nil},
{"nil slice", ([]int)(nil), 0, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SliceIndex(tt.slice, tt.index)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
}
})
}
}
func TestCompareValues(t *testing.T) {
tests := []struct {
name string
a interface{}
b interface{}
want int
}{
{"both nil", nil, nil, 0},
{"a nil", nil, 5, -1},
{"b nil", 5, nil, 1},
{"equal strings", "abc", "abc", 0},
{"a less than b strings", "abc", "xyz", -1},
{"a greater than b strings", "xyz", "abc", 1},
{"equal ints", 5, 5, 0},
{"a less than b ints", 3, 7, -1},
{"a greater than b ints", 10, 5, 1},
{"equal floats", 3.14, 3.14, 0},
{"a less than b floats", 2.5, 5.5, -1},
{"a greater than b floats", 10.5, 5.5, 1},
{"equal uints", uint(5), uint(5), 0},
{"different types", "abc", 123, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := CompareValues(tt.a, tt.b)
if got != tt.want {
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}
func TestGetNestedValue(t *testing.T) {
nested := map[string]interface{}{
"level1": map[string]interface{}{
"level2": map[string]interface{}{
"value": "deep",
},
},
}
ts := testStruct{
Name: "John",
Nested: &nestedStruct{
Value: "nested value",
Count: 42,
},
}
tests := []struct {
name string
input interface{}
path string
want interface{}
}{
{"empty path", nested, "", nested},
{"single level map", nested, "level1", nested["level1"]},
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
{"deep nested map", nested, "level1.level2.value", "deep"},
{"struct field", ts, "Name", "John"},
{"nested struct field", ts, "Nested", ts.Nested},
{"non-existent path", nested, "missing.path", nil},
{"nil input", nil, "path", nil},
{"partial missing path", nested, "level1.missing", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetNestedValue(tt.input, tt.path)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
}
})
}
}
func TestDeepEqual(t *testing.T) {
tests := []struct {
name string
a interface{}
b interface{}
want bool
}{
{"equal ints", 42, 42, true},
{"different ints", 42, 43, false},
{"equal strings", "test", "test", true},
{"different strings", "test", "other", false},
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
{"both nil", nil, nil, true},
{"one nil", nil, 42, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DeepEqual(tt.a, tt.b)
if got != tt.want {
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}
// Helper functions
func ptrInt(i int) *int {
return &i
}
func ptrPtr(p *int) **int {
return &p
}
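As a quick orientation for how these helpers combine outside the tests (a sketch using only functions exercised above):

```go
cfg := map[string]interface{}{
	"db": map[string]interface{}{
		"host":  "localhost",
		"ports": []int{5432, 5433},
	},
}
host := GetNestedValue(cfg, "db.host")             // "localhost"
count := SliceLen(GetNestedValue(cfg, "db.ports")) // 2
_, _ = host, count
```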


@@ -106,19 +106,20 @@ func (td *TemplateData) FinalizeImports() {
}
// NewModelData creates a new ModelData from a models.Table
-func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
-tableName := table.Name
-if schema != "" {
-tableName = schema + "." + table.Name
-}
+func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
+tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
-// Generate model name: singularize and convert to PascalCase
+// Generate model name: Model + Schema + Table (all PascalCase)
singularTable := Singularize(table.Name)
-modelName := SnakeCaseToPascalCase(singularTable)
+tablePart := SnakeCaseToPascalCase(singularTable)
-// Add "Model" prefix if not already present
-if !hasModelPrefix(modelName) {
-modelName = "Model" + modelName
+// Include schema name in model name
+var modelName string
+if schema != "" {
+schemaPart := SnakeCaseToPascalCase(schema)
+modelName = "Model" + schemaPart + tablePart
+} else {
+modelName = "Model" + tablePart
}
model := &ModelData{
@@ -149,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
columns := sortColumns(table.Columns)
for _, col := range columns {
field := columnToField(col, table, typeMapper)
// Check for name collision with generated methods and rename if needed
field.Name = resolveFieldNameCollision(field.Name)
model.Fields = append(model.Fields, field)
}
@@ -190,9 +193,28 @@ func formatComment(description, comment string) string {
return comment
}
-// hasModelPrefix checks if a name already has "Model" prefix
-func hasModelPrefix(name string) bool {
-return len(name) >= 5 && name[:5] == "Model"
// resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string {
// List of method names that are generated by the template
reservedNames := map[string]bool{
"TableName": true,
"TableNameOnly": true,
"SchemaName": true,
"GetID": true,
"GetIDStr": true,
"SetID": true,
"UpdateID": true,
"GetIDName": true,
"GetPrefix": true,
}
// Check if field name conflicts with a reserved method name
if reservedNames[fieldName] {
return fieldName + "_"
}
return fieldName
}
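For example, a column named `table_name` would otherwise render as a `TableName` field and shadow the generated method:

```go
resolveFieldNameCollision("TableName") // -> "TableName_" (reserved, suffixed)
resolveFieldNameCollision("Email")     // -> "Email"      (unchanged)
```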
// sortColumns sorts columns by sequence, then by name


@@ -4,6 +4,7 @@ import (
"fmt"
"go/format"
"os"
"os/exec"
"path/filepath"
"strings"
@@ -85,7 +86,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
// Collect all models
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
-modelData := NewModelData(table, schema.Name, w.typeMapper)
+modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
// Add relationship fields
w.addRelationshipFields(modelData, table, schema, db)
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
}
// Write output
-return w.writeOutput(formatted)
+if err := w.writeOutput(formatted); err != nil {
+return err
+}
+// Run go fmt on the output file
+if w.options.OutputPath != "" {
+w.runGoFmt(w.options.OutputPath)
+}
+return nil
}
// writeMultiFile writes each table to a separate file
@@ -171,7 +181,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
// Create model data
-modelData := NewModelData(table, schema.Name, w.typeMapper)
+modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
// Add relationship fields
w.addRelationshipFields(modelData, table, schema, db)
@@ -217,6 +227,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
return fmt.Errorf("failed to write file %s: %w", filename, err)
}
// Run go fmt on the generated file
w.runGoFmt(filepath)
}
}
@@ -225,6 +238,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
// addRelationshipFields adds relationship fields to the model based on foreign keys
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
// Track used field names to detect duplicates
usedFieldNames := make(map[string]int)
// For each foreign key in this table, add a belongs-to/has-one relationship
for _, constraint := range table.Constraints {
if constraint.Type != models.ForeignKeyConstraint {
@@ -238,8 +254,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
}
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
-refModelName := w.getModelName(constraint.ReferencedTable)
-fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
+refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
+fieldName := w.generateHasOneFieldName(constraint)
+fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
modelData.AddRelationshipField(&FieldData{
@@ -266,8 +283,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
// Check if this constraint references our table
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
// Add has-many relationship
-otherModelName := w.getModelName(otherTable.Name)
-fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
+otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
+fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
+fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
modelData.AddRelationshipField(&FieldData{
@@ -298,22 +316,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
return nil
}
-// getModelName generates the model name from a table name
-func (w *Writer) getModelName(tableName string) string {
+// getModelName generates the model name from schema and table name
+func (w *Writer) getModelName(schemaName, tableName string) string {
singular := Singularize(tableName)
-modelName := SnakeCaseToPascalCase(singular)
+tablePart := SnakeCaseToPascalCase(singular)
-if !hasModelPrefix(modelName) {
-modelName = "Model" + modelName
+// Include schema name in model name
+var modelName string
+if schemaName != "" {
+schemaPart := SnakeCaseToPascalCase(schemaName)
+modelName = "Model" + schemaPart + tablePart
+} else {
+modelName = "Model" + tablePart
}
return modelName
}
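Concretely, matching the expectations in the tests below:

```go
w.getModelName("public", "users") // -> "ModelPublicUser"
w.getModelName("", "users")       // -> "ModelUser" (no schema part)
```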
-// generateRelationshipFieldName generates a field name for a relationship
-func (w *Writer) generateRelationshipFieldName(tableName string) string {
-// Use just the prefix (3 letters) for relationship fields
-return GeneratePrefix(tableName)
// generateHasOneFieldName generates a field name for has-one relationships
// Uses the foreign key column name for uniqueness
func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
// Use the foreign key column name to ensure uniqueness
// If there are multiple columns, use the first one
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Convert to PascalCase for proper Go field naming
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
return "Rel" + SnakeCaseToPascalCase(columnName)
}
// Fallback to table-based prefix if no columns defined
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
}
// generateHasManyFieldName generates a field name for has-many relationships
// Uses the foreign key column name + source table name to avoid duplicates
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
// For has-many, we need to include the source table name to avoid duplicates
// e.g., multiple tables referencing the same column on this table
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Get the model name for the source table (pluralized)
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
// Remove "Model" prefix if present
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
// Convert column to PascalCase and combine with source table
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
columnPart := SnakeCaseToPascalCase(columnName)
return "Rel" + columnPart + Pluralize(sourceModelName)
}
// Fallback to table-based naming
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
return "Rel" + Pluralize(sourceModelName)
}
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
originalName := fieldName
count := usedNames[originalName]
if count > 0 {
// Name is already used, add numeric suffix
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
}
// Increment the counter for this base name
usedNames[originalName]++
return fieldName
}
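Worked through the counter logic above: the first use of a base name keeps it, and later uses get numeric suffixes starting at 2:

```go
used := map[string]int{}
w.ensureUniqueFieldName("RelOwner", used) // -> "RelOwner"
w.ensureUniqueFieldName("RelOwner", used) // -> "RelOwner2"
w.ensureUniqueFieldName("RelOwner", used) // -> "RelOwner3"
```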
// getPackageName returns the package name from options or defaults to "models"
@@ -344,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
return nil
}
// runGoFmt runs go fmt on the specified file
func (w *Writer) runGoFmt(filepath string) {
cmd := exec.Command("gofmt", "-w", filepath)
if err := cmd.Run(); err != nil {
// Don't fail the whole operation if gofmt fails, just warn
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
}
}
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
func (w *Writer) shouldUseMultiFile() bool {
// Check if multi_file is explicitly set in metadata


@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
// Verify key elements are present
expectations := []string{
"package models",
"type ModelUser struct",
"type ModelPublicUser struct",
"bun.BaseModel",
"table:public.users",
"alias:users",
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
"resolvespec_common.SqlTime",
"bun:\"id",
"bun:\"email",
"func (m ModelUser) TableName() string",
"func (m ModelPublicUser) TableName() string",
"return \"public.users\"",
"func (m ModelUser) GetID() int64",
"func (m ModelPublicUser) GetID() int64",
}
for _, expected := range expectations {
@@ -175,12 +175,378 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
postsStr := string(postsContent)
// Verify relationship is present with Bun format
if !strings.Contains(postsStr, "USE") {
t.Errorf("Missing relationship field USE")
// Should now be RelUserID (has-one) instead of USE
if !strings.Contains(postsStr, "RelUserID") {
t.Errorf("Missing relationship field RelUserID (new naming convention)")
}
if !strings.Contains(postsStr, "rel:has-one") {
t.Errorf("Missing Bun relationship tag: %s", postsStr)
}
// Check users file contains has-many relationship
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
if err != nil {
t.Fatalf("Failed to read users file: %v", err)
}
usersStr := string(usersContent)
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
}
}
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
// Test scenario: api_event table with multiple foreign keys to filepointer table
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, filepointer)
// API event table with two foreign keys to filepointer
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
Name: "rid_filepointer_request",
Type: "bigint",
NotNull: false,
}
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
Name: "rid_filepointer_response",
Type: "bigint",
NotNull: false,
}
// Add constraints
apiEvent.Constraints["fk_request"] = &models.Constraint{
Name: "fk_request",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_request"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
apiEvent.Constraints["fk_response"] = &models.Constraint{
Name: "fk_response",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_response"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_event file
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
if err != nil {
t.Fatalf("Failed to read api_event file: %v", err)
}
contentStr := string(apiEventContent)
// Verify both relationships have unique names based on column names
expectations := []struct {
fieldName string
tag string
}{
{"RelRIDFilepointerRequest", "join:rid_filepointer_request=id_filepointer"},
{"RelRIDFilepointerResponse", "join:rid_filepointer_response=id_filepointer"},
}
for _, exp := range expectations {
if !strings.Contains(contentStr, exp.fieldName) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
}
if !strings.Contains(contentStr, exp.tag) {
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
}
}
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
}
// Also verify has-many relationships on filepointer table
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
if err != nil {
t.Fatalf("Failed to read filepointer file: %v", err)
}
filepointerStr := string(filepointerContent)
// Should have two different has-many relationships with unique names
hasManyExpectations := []string{
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
}
for _, exp := range hasManyExpectations {
if !strings.Contains(filepointerStr, exp) {
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
}
}
}
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Owner table
owner := models.InitTable("owner", "org")
owner.Columns["id_owner"] = &models.Column{
Name: "id_owner",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, owner)
// API Provider table
apiProvider := models.InitTable("api_provider", "org")
apiProvider.Columns["id_api_provider"] = &models.Column{
Name: "id_api_provider",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiProvider.Columns["rid_owner"] = &models.Column{
Name: "rid_owner",
Type: "bigint",
NotNull: true,
}
apiProvider.Constraints["fk_owner"] = &models.Constraint{
Name: "fk_owner",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_owner"},
ReferencedTable: "owner",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_owner"},
}
schema.Tables = append(schema.Tables, apiProvider)
// Login table
login := models.InitTable("login", "org")
login.Columns["id_login"] = &models.Column{
Name: "id_login",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
login.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
login.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, login)
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
filepointer.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, filepointer)
// API Event table
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_provider file
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
if err != nil {
t.Fatalf("Failed to read api_provider file: %v", err)
}
contentStr := string(apiProviderContent)
// Verify all relationship fields (has-many plus the has-one) have unique names
hasManyExpectations := []string{
"RelRIDAPIProviderOrgLogins", // Has many via Login
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
"RelRIDOwner", // Has one via rid_owner
}
for _, exp := range hasManyExpectations {
if !strings.Contains(contentStr, exp) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
}
}
// Verify NO duplicate field names
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
count := strings.Count(contentStr, "RelRIDAPIProvider")
if count != 3 {
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
}
// Verify no duplicate declarations (would cause compilation error)
duplicatePattern := "RelRIDAPIProviders []*Model"
if strings.Contains(contentStr, duplicatePattern) {
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
}
}
func TestWriter_FieldNameCollision(t *testing.T) {
// Test scenario: table with columns that would conflict with generated method names
table := models.InitTable("audit_table", "audit")
table.Columns["id_audit_table"] = &models.Column{
Name: "id_audit_table",
Type: "smallint",
NotNull: true,
IsPrimaryKey: true,
Sequence: 1,
}
table.Columns["table_name"] = &models.Column{
Name: "table_name",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 2,
}
table.Columns["table_schema"] = &models.Column{
Name: "table_schema",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 3,
}
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
// Read the generated file
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
// Verify that TableName field was renamed to TableName_ to avoid collision
if !strings.Contains(generated, "TableName_") {
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
}
// Verify the struct tag still references the correct database column
if !strings.Contains(generated, `bun:"table_name,`) {
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
}
// Verify the TableName() method still exists and doesn't conflict
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
}
// Verify NO field named just "TableName" (without underscore)
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
}
}
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {

View File

@@ -25,6 +25,7 @@ type ModelData struct {
Fields []*FieldData
Config *MethodConfig
PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix
}
@@ -104,19 +105,20 @@ func (td *TemplateData) FinalizeImports() {
}
// NewModelData creates a new ModelData from a models.Table
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
tableName := table.Name
if schema != "" {
tableName = schema + "." + table.Name
}
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
// Generate model name: singularize and convert to PascalCase
// Generate model name: Model + Schema + Table (all PascalCase)
singularTable := Singularize(table.Name)
modelName := SnakeCaseToPascalCase(singularTable)
tablePart := SnakeCaseToPascalCase(singularTable)
// Add "Model" prefix if not already present
if !hasModelPrefix(modelName) {
modelName = "Model" + modelName
// Include schema name in model name
var modelName string
if schema != "" {
schemaPart := SnakeCaseToPascalCase(schema)
modelName = "Model" + schemaPart + tablePart
} else {
modelName = "Model" + tablePart
}
model := &ModelData{
@@ -135,6 +137,7 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
// Sanitize column name to remove backticks
safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.IDColumnName = safeName
break
}
@@ -144,6 +147,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
columns := sortColumns(table.Columns)
for _, col := range columns {
field := columnToField(col, table, typeMapper)
// Check for name collision with generated methods and rename if needed
field.Name = resolveFieldNameCollision(field.Name)
model.Fields = append(model.Fields, field)
}
@@ -185,9 +190,28 @@ func formatComment(description, comment string) string {
return comment
}
// hasModelPrefix checks if a name already has "Model" prefix
func hasModelPrefix(name string) bool {
return len(name) >= 5 && name[:5] == "Model"
// resolveFieldNameCollision checks if a field name conflicts with generated method names
// and adds an underscore suffix if there's a collision
func resolveFieldNameCollision(fieldName string) string {
// List of method names that are generated by the template
reservedNames := map[string]bool{
"TableName": true,
"TableNameOnly": true,
"SchemaName": true,
"GetID": true,
"GetIDStr": true,
"SetID": true,
"UpdateID": true,
"GetIDName": true,
"GetPrefix": true,
}
// Check if field name conflicts with a reserved method name
if reservedNames[fieldName] {
return fieldName + "_"
}
return fieldName
}
// sortColumns sorts columns by sequence, then by name

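A minimal standalone sketch of the collision rule above (hand-written for illustration; reservedNames is copied from the diff, everything else is hypothetical):

package main

import "fmt"

// reservedNames mirrors the generated-method names listed above.
var reservedNames = map[string]bool{
	"TableName": true, "TableNameOnly": true, "SchemaName": true,
	"GetID": true, "GetIDStr": true, "SetID": true,
	"UpdateID": true, "GetIDName": true, "GetPrefix": true,
}

// resolveFieldNameCollision appends an underscore when a column's
// PascalCase field name would shadow a generated method.
func resolveFieldNameCollision(fieldName string) string {
	if reservedNames[fieldName] {
		return fieldName + "_"
	}
	return fieldName
}

func main() {
	fmt.Println(resolveFieldNameCollision("TableName")) // TableName_
	fmt.Println(resolveFieldNameCollision("Email"))     // Email
}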
View File

@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) {
m.{{.PrimaryKeyField}} = int32(newid)
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
}
{{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}}

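Given the test expectations later in this change set (smallint maps to int16), the updated template should render roughly the following for a smallint primary key; the model and field names here are illustrative, not actual generated output:

// UpdateID updates the primary key value
func (m *ModelPublicTestTable) UpdateID(newid int64) {
	m.ID = int16(newid)
}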
View File

@@ -4,6 +4,7 @@ import (
"fmt"
"go/format"
"os"
"os/exec"
"path/filepath"
"strings"
@@ -82,7 +83,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
// Collect all models
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
modelData := NewModelData(table, schema.Name, w.typeMapper)
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
// Add relationship fields
w.addRelationshipFields(modelData, table, schema, db)
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
}
// Write output
return w.writeOutput(formatted)
if err := w.writeOutput(formatted); err != nil {
return err
}
// Run go fmt on the output file
if w.options.OutputPath != "" {
w.runGoFmt(w.options.OutputPath)
}
return nil
}
// writeMultiFile writes each table to a separate file
@@ -165,7 +175,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
// Create model data
modelData := NewModelData(table, schema.Name, w.typeMapper)
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
// Add relationship fields
w.addRelationshipFields(modelData, table, schema, db)
@@ -211,6 +221,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
return fmt.Errorf("failed to write file %s: %w", filename, err)
}
// Run go fmt on the generated file
w.runGoFmt(filepath)
}
}
@@ -219,6 +232,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
// addRelationshipFields adds relationship fields to the model based on foreign keys
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
// Track used field names to detect duplicates
usedFieldNames := make(map[string]int)
// For each foreign key in this table, add a belongs-to relationship
for _, constraint := range table.Constraints {
if constraint.Type != models.ForeignKeyConstraint {
@@ -232,8 +248,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
}
// Create relationship field (belongs-to)
refModelName := w.getModelName(constraint.ReferencedTable)
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
fieldName := w.generateBelongsToFieldName(constraint)
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
modelData.AddRelationshipField(&FieldData{
@@ -260,8 +277,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
// Check if this constraint references our table
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
// Add has-many relationship
otherModelName := w.getModelName(otherTable.Name)
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
modelData.AddRelationshipField(&FieldData{
@@ -292,22 +310,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
return nil
}
// getModelName generates the model name from a table name
func (w *Writer) getModelName(tableName string) string {
// getModelName generates the model name from schema and table name
func (w *Writer) getModelName(schemaName, tableName string) string {
singular := Singularize(tableName)
modelName := SnakeCaseToPascalCase(singular)
tablePart := SnakeCaseToPascalCase(singular)
if !hasModelPrefix(modelName) {
modelName = "Model" + modelName
// Include schema name in model name
var modelName string
if schemaName != "" {
schemaPart := SnakeCaseToPascalCase(schemaName)
modelName = "Model" + schemaPart + tablePart
} else {
modelName = "Model" + tablePart
}
return modelName
}
// generateRelationshipFieldName generates a field name for a relationship
func (w *Writer) generateRelationshipFieldName(tableName string) string {
// Use just the prefix (3 letters) for relationship fields
return GeneratePrefix(tableName)
// generateBelongsToFieldName generates a field name for belongs-to relationships
// Uses the foreign key column name for uniqueness
func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) string {
// Use the foreign key column name to ensure uniqueness
// If there are multiple columns, use the first one
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Convert to PascalCase for proper Go field naming
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
return "Rel" + SnakeCaseToPascalCase(columnName)
}
// Fallback to table-based prefix if no columns defined
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
}
// generateHasManyFieldName generates a field name for has-many relationships
// Uses the foreign key column name + source table name to avoid duplicates
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
// For has-many, we need to include the source table name to avoid duplicates
// e.g., multiple tables referencing the same column on this table
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Get the model name for the source table (pluralized)
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
// Remove "Model" prefix if present
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
// Convert column to PascalCase and combine with the schema-qualified source model
// e.g., "rid_api_provider" + "OrgLogin" -> "RelRIDAPIProviderOrgLogins"
columnPart := SnakeCaseToPascalCase(columnName)
return "Rel" + columnPart + Pluralize(sourceModelName)
}
// Fallback to table-based naming
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
return "Rel" + Pluralize(sourceModelName)
}
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
originalName := fieldName
count := usedNames[originalName]
if count > 0 {
// Name is already used, add numeric suffix
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
}
// Increment the counter for this base name
usedNames[originalName]++
return fieldName
}
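A standalone sketch of the suffixing behavior (re-implemented here for illustration, outside the Writer type):

package main

import "fmt"

// ensureUnique mirrors ensureUniqueFieldName: the first use of a name is
// kept as-is, later uses get numeric suffixes starting at 2.
func ensureUnique(name string, used map[string]int) string {
	out := name
	if count := used[name]; count > 0 {
		out = fmt.Sprintf("%s%d", name, count+1)
	}
	used[name]++
	return out
}

func main() {
	used := map[string]int{}
	fmt.Println(ensureUnique("RelRIDOwner", used)) // RelRIDOwner
	fmt.Println(ensureUnique("RelRIDOwner", used)) // RelRIDOwner2
	fmt.Println(ensureUnique("RelRIDOwner", used)) // RelRIDOwner3
}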
// getPackageName returns the package name from options or defaults to "models"
@@ -338,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
return nil
}
// runGoFmt runs gofmt on the specified file
func (w *Writer) runGoFmt(path string) {
cmd := exec.Command("gofmt", "-w", path)
if err := cmd.Run(); err != nil {
// Don't fail the whole operation if gofmt fails, just warn
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
}
}
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
func (w *Writer) shouldUseMultiFile() bool {
// Check if multi_file is explicitly set in metadata

View File

@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
// Verify key elements are present
expectations := []string{
"package models",
"type ModelUser struct",
"type ModelPublicUser struct",
"ID",
"int64",
"Email",
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
"time.Time",
"gorm:\"column:id",
"gorm:\"column:email",
"func (m ModelUser) TableName() string",
"func (m ModelPublicUser) TableName() string",
"return \"public.users\"",
"func (m ModelUser) GetID() int64",
"func (m ModelPublicUser) GetID() int64",
}
for _, expected := range expectations {
@@ -164,9 +164,437 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
t.Fatalf("Failed to read posts file: %v", err)
}
if !strings.Contains(string(postsContent), "USE *ModelUser") {
// Relationship field should be present
t.Logf("Posts content:\n%s", string(postsContent))
postsStr := string(postsContent)
// Verify relationship is present with new naming convention
// Should now be RelUserID (belongs-to) instead of USE
if !strings.Contains(postsStr, "RelUserID") {
t.Errorf("Missing relationship field RelUserID (new naming convention)")
}
// Check users file contains has-many relationship
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
if err != nil {
t.Fatalf("Failed to read users file: %v", err)
}
usersStr := string(usersContent)
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
}
}
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
// Test scenario: api_event table with multiple foreign keys to filepointer table
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, filepointer)
// API event table with two foreign keys to filepointer
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
Name: "rid_filepointer_request",
Type: "bigint",
NotNull: false,
}
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
Name: "rid_filepointer_response",
Type: "bigint",
NotNull: false,
}
// Add constraints
apiEvent.Constraints["fk_request"] = &models.Constraint{
Name: "fk_request",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_request"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
apiEvent.Constraints["fk_response"] = &models.Constraint{
Name: "fk_response",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_response"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_event file
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
if err != nil {
t.Fatalf("Failed to read api_event file: %v", err)
}
contentStr := string(apiEventContent)
// Verify both relationships have unique names based on column names
expectations := []struct {
fieldName string
tag string
}{
{"RelRIDFilepointerRequest", "foreignKey:RIDFilepointerRequest"},
{"RelRIDFilepointerResponse", "foreignKey:RIDFilepointerResponse"},
}
for _, exp := range expectations {
if !strings.Contains(contentStr, exp.fieldName) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
}
if !strings.Contains(contentStr, exp.tag) {
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
}
}
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
}
// Also verify has-many relationships on filepointer table
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
if err != nil {
t.Fatalf("Failed to read filepointer file: %v", err)
}
filepointerStr := string(filepointerContent)
// Should have two different has-many relationships with unique names
hasManyExpectations := []string{
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
}
for _, exp := range hasManyExpectations {
if !strings.Contains(filepointerStr, exp) {
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
}
}
}
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Owner table
owner := models.InitTable("owner", "org")
owner.Columns["id_owner"] = &models.Column{
Name: "id_owner",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, owner)
// API Provider table
apiProvider := models.InitTable("api_provider", "org")
apiProvider.Columns["id_api_provider"] = &models.Column{
Name: "id_api_provider",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiProvider.Columns["rid_owner"] = &models.Column{
Name: "rid_owner",
Type: "bigint",
NotNull: true,
}
apiProvider.Constraints["fk_owner"] = &models.Constraint{
Name: "fk_owner",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_owner"},
ReferencedTable: "owner",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_owner"},
}
schema.Tables = append(schema.Tables, apiProvider)
// Login table
login := models.InitTable("login", "org")
login.Columns["id_login"] = &models.Column{
Name: "id_login",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
login.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
login.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, login)
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
filepointer.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, filepointer)
// API Event table
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_provider file
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
if err != nil {
t.Fatalf("Failed to read api_provider file: %v", err)
}
contentStr := string(apiProviderContent)
// Verify all relationship fields (has-many plus the belongs-to) have unique names
hasManyExpectations := []string{
"RelRIDAPIProviderOrgLogins", // Has many via Login
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
"RelRIDOwner", // Belongs to via rid_owner
}
for _, exp := range hasManyExpectations {
if !strings.Contains(contentStr, exp) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
}
}
// Verify NO duplicate field names
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
count := strings.Count(contentStr, "RelRIDAPIProvider")
if count != 3 {
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
}
// Verify no duplicate declarations (would cause compilation error)
duplicatePattern := "RelRIDAPIProviders []*Model"
if strings.Contains(contentStr, duplicatePattern) {
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
}
}
func TestWriter_FieldNameCollision(t *testing.T) {
// Test scenario: table with columns that would conflict with generated method names
table := models.InitTable("audit_table", "audit")
table.Columns["id_audit_table"] = &models.Column{
Name: "id_audit_table",
Type: "smallint",
NotNull: true,
IsPrimaryKey: true,
Sequence: 1,
}
table.Columns["table_name"] = &models.Column{
Name: "table_name",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 2,
}
table.Columns["table_schema"] = &models.Column{
Name: "table_schema",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 3,
}
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
// Read the generated file
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
// Verify that TableName field was renamed to TableName_ to avoid collision
if !strings.Contains(generated, "TableName_") {
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
}
// Verify the struct tag still references the correct database column
if !strings.Contains(generated, `gorm:"column:table_name;`) {
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
}
// Verify the TableName() method still exists and doesn't conflict
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
}
// Verify NO field named just "TableName" (without underscore)
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
}
}
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
// Test scenario: tables with different primary key types
tests := []struct {
name string
pkType string
expectedPK string
castType string
}{
{"int32_pk", "int", "int32", "int32(newid)"},
{"int16_pk", "smallint", "int16", "int16(newid)"},
{"int64_pk", "bigint", "int64", "int64(newid)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table := models.InitTable("test_table", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: tt.pkType,
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
// Verify UpdateID method has correct type cast
if !strings.Contains(generated, tt.castType) {
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
}
// Verify no invalid int32(newid) for non-int32 types
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
}
// Verify UpdateID parameter is int64 (for consistency)
if !strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
}
})
}
}

View File

@@ -0,0 +1,217 @@
# PostgreSQL Naming Conventions
Standardized naming rules for all database objects in RelSpec PostgreSQL output.
## Quick Reference
| Object Type | Prefix | Format | Example |
| ----------------- | ----------- | ---------------------------------- | ------------------------ |
| Primary Key | `pk_` | `pk_<schema>_<table>` | `pk_public_users` |
| Foreign Key | `fk_` | `fk_<table>_<referenced_table>` | `fk_posts_users` |
| Unique Constraint | `uk_` | `uk_<table>_<column>` | `uk_users_email` |
| Unique Index | `uidx_` | `uidx_<table>_<column>` | `uidx_users_email` |
| Regular Index | `idx_` | `idx_<table>_<column>` | `idx_posts_user_id` |
| Check Constraint | `chk_` | `chk_<table>_<constraint_purpose>` | `chk_users_age_positive` |
| Sequence | `identity_` | `identity_<table>_<column>` | `identity_users_id` |
| Trigger | `t_` | `t_<purpose>_<table>` | `t_audit_users` |
| Trigger Function | `tf_` | `tf_<purpose>_<table>` | `tf_audit_users` |
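The same rules can be expressed as small formatting helpers. The sketch below is a hand-written illustration of the table above, not code from this repository:

```go
package main

import "fmt"

// Illustrative builders for the most common prefixes.
func pkName(schema, table string) string  { return fmt.Sprintf("pk_%s_%s", schema, table) }
func fkName(table, ref string) string     { return fmt.Sprintf("fk_%s_%s", table, ref) }
func ukName(table, column string) string  { return fmt.Sprintf("uk_%s_%s", table, column) }
func idxName(table, column string) string { return fmt.Sprintf("idx_%s_%s", table, column) }
func seqName(table, column string) string { return fmt.Sprintf("identity_%s_%s", table, column) }

func main() {
	fmt.Println(pkName("public", "users"))   // pk_public_users
	fmt.Println(fkName("posts", "users"))    // fk_posts_users
	fmt.Println(ukName("users", "email"))    // uk_users_email
	fmt.Println(idxName("posts", "user_id")) // idx_posts_user_id
	fmt.Println(seqName("users", "id"))      // identity_users_id
}
```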
## Naming Rules by Object Type
### Primary Keys
**Pattern:** `pk_<schema>_<table>`
- Include schema name to avoid collisions across schemas
- Use lowercase, snake_case format
- Examples:
- `pk_public_users`
- `pk_audit_audit_log`
- `pk_staging_temp_data`
### Foreign Keys
**Pattern:** `fk_<table>_<referenced_table>`
- Reference the table containing the FK followed by the referenced table
- Use lowercase, snake_case format
- Do NOT include column names in standard FK constraints
- Examples:
- `fk_posts_users` (posts.user_id → users.id)
- `fk_comments_posts` (comments.post_id → posts.id)
- `fk_order_items_orders` (order_items.order_id → orders.id)
### Unique Constraints
**Pattern:** `uk_<table>_<column>`
- Use `uk_` prefix strictly for database constraints (CONSTRAINT type)
- Include column name for clarity
- Examples:
- `uk_users_email`
- `uk_users_username`
- `uk_products_sku`
### Unique Indexes
**Pattern:** `uidx_<table>_<column>`
- Use `uidx_` prefix strictly for index type objects
- Distinguished from constraints for clarity and implementation flexibility
- Examples:
- `uidx_users_email`
- `uidx_sessions_token`
- `uidx_api_keys_key`
### Regular Indexes
**Pattern:** `idx_<table>_<column>`
- Standard indexes for query optimization
- Single column: `idx_<table>_<column>`
- Examples:
- `idx_posts_user_id`
- `idx_orders_created_at`
- `idx_users_status`
### Check Constraints
**Pattern:** `chk_<table>_<constraint_purpose>`
- Describe the constraint validation purpose
- Use lowercase, snake_case for the purpose
- Examples:
- `chk_users_age_positive` (CHECK (age > 0))
- `chk_orders_quantity_positive` (CHECK (quantity > 0))
- `chk_products_price_valid` (CHECK (price >= 0))
- `chk_users_status_enum` (CHECK (status IN ('active', 'inactive')))
### Sequences
**Pattern:** `identity_<table>_<column>`
- Used for SERIAL/IDENTITY columns
- Explicitly named for clarity and management
- Examples:
- `identity_users_id`
- `identity_posts_id`
- `identity_transactions_id`
### Triggers
**Pattern:** `t_<purpose>_<table>`
- Include purpose before table name
- Lowercase, snake_case format
- Examples:
- `t_audit_users` (audit trigger on users table)
- `t_update_timestamp_posts` (timestamp update trigger on posts)
- `t_validate_orders` (validation trigger on orders)
### Trigger Functions
**Pattern:** `tf_<purpose>_<table>`
- Pair with trigger naming convention
- Use `tf_` prefix to distinguish from triggers themselves
- Examples:
- `tf_audit_users` (function for t_audit_users)
- `tf_update_timestamp_posts` (function for t_update_timestamp_posts)
- `tf_validate_orders` (function for t_validate_orders)
## Multi-Column Objects
### Composite Primary Keys
**Pattern:** `pk_<schema>_<table>`
- Same as single-column PKs
- Example: `pk_public_order_items` (composite key on order_id + item_id)
### Composite Unique Constraints
**Pattern:** `uk_<table>_<column1>_<column2>_[...]`
- Append all column names in order
- Examples:
- `uk_users_email_domain` (UNIQUE(email, domain))
- `uk_inventory_warehouse_sku` (UNIQUE(warehouse_id, sku))
### Composite Unique Indexes
**Pattern:** `uidx_<table>_<column1>_<column2>_[...]`
- Append all column names in order
- Examples:
- `uidx_users_first_name_last_name` (UNIQUE INDEX on first_name, last_name)
- `uidx_sessions_user_id_device_id` (UNIQUE INDEX on user_id, device_id)
### Composite Regular Indexes
**Pattern:** `idx_<table>_<column1>_<column2>_[...]`
- Append all column names in order
- List columns in typical query filter order
- Examples:
- `idx_orders_user_id_created_at` (filter by user, then sort by created_at)
- `idx_logs_level_timestamp` (filter by level, then by timestamp)
## Special Cases & Conventions
### Audit Trail Tables
- Audit table naming: `<original_table>_audit` or `audit_<original_table>`
- Audit indexes follow standard pattern: `idx_<audit_table>_<column>`
- Examples:
- Users table audit: `users_audit` with `idx_users_audit_tablename`, `idx_users_audit_changedate`
- Posts table audit: `posts_audit` with `idx_posts_audit_tablename`, `idx_posts_audit_changedate`
### Temporal/Versioning Tables
- Use suffix `_history` or `_versions` if needed
- Apply standard naming rules with the full table name
- Examples:
- `idx_users_history_user_id`
- `uk_posts_versions_version_number`
### Schema-Specific Objects
- Always qualify with schema when needed: `pk_<schema>_<table>`
- Multiple schemas allowed: `pk_public_users`, `pk_staging_users`
### Reserved Words & Special Names
- Avoid PostgreSQL reserved keywords in object names
- If column/table names conflict, use quoted identifiers in DDL
- Naming convention rules still apply to the logical name
### Generated/Anonymous Indexes
- If an index lacks explicit naming, default to: `idx_<schema>_<table>`
- Should be replaced with explicit names following standards
- Examples (to be renamed):
- `idx_public_users` → should be `idx_users_<column>`
## Implementation Notes
### Code Generation
- Names are always lowercase in generated SQL
- Underscore separators are required
### Migration Safety
- Do NOT rename objects after creation without explicit migration
- Names should be consistent across all schema versions
- Test generated DDL against PostgreSQL before deployment
### Testing
- Ensure consistency across all table and constraint generation
- Test with reserved words to verify escaping
## Related Documentation
- PostgreSQL Identifier Rules: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-IDENTIFIERS
- Constraint Documentation: https://www.postgresql.org/docs/current/ddl-constraints.html
- Index Documentation: https://www.postgresql.org/docs/current/indexes.html

View File

@@ -8,6 +8,7 @@ import (
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
@@ -30,7 +31,7 @@ type MigrationWriter struct {
// NewMigrationWriter creates a new templated migration writer
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
executor, err := NewTemplateExecutor()
executor, err := NewTemplateExecutor(options.FlattenSchema)
if err != nil {
return nil, fmt.Errorf("failed to create template executor: %w", err)
}
@@ -335,7 +336,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
SchemaName: schema.Name,
TableName: modelTable.Name,
ColumnName: modelCol.Name,
ColumnType: modelCol.Type,
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
Default: defaultVal,
NotNull: modelCol.NotNull,
})
@@ -359,7 +360,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
SchemaName: schema.Name,
TableName: modelTable.Name,
ColumnName: modelCol.Name,
NewType: modelCol.Type,
NewType: pgsql.ConvertSQLType(modelCol.Type),
})
if err != nil {
return nil, err
@@ -427,9 +428,11 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
for _, modelTable := range model.Tables {
currentTable := currentTables[strings.ToLower(modelTable.Name)]
// Process primary keys first
// Process primary keys first - check explicit constraints
foundExplicitPK := false
for constraintName, constraint := range modelTable.Constraints {
if constraint.Type == models.PrimaryKeyConstraint {
foundExplicitPK = true
shouldCreate := true
if currentTable != nil {
@@ -464,6 +467,53 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
}
}
// If no explicit PK constraint, check for columns with IsPrimaryKey = true
if !foundExplicitPK {
pkColumns := []string{}
for _, col := range modelTable.Columns {
if col.IsPrimaryKey {
pkColumns = append(pkColumns, col.SQLName())
}
}
if len(pkColumns) > 0 {
sort.Strings(pkColumns)
constraintName := fmt.Sprintf("pk_%s_%s", model.SQLName(), modelTable.SQLName())
shouldCreate := true
if currentTable != nil {
// Check if a PK constraint already exists (by any name)
for _, constraint := range currentTable.Constraints {
if constraint.Type == models.PrimaryKeyConstraint {
shouldCreate = false
break
}
}
}
if shouldCreate {
sql, err := w.executor.ExecuteCreatePrimaryKey(CreatePrimaryKeyData{
SchemaName: model.Name,
TableName: modelTable.Name,
ConstraintName: constraintName,
Columns: strings.Join(pkColumns, ", "),
})
if err != nil {
return nil, err
}
script := MigrationScript{
ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName),
ObjectType: "create primary key",
Schema: model.Name,
Priority: 160,
Sequence: len(scripts),
Body: sql,
}
scripts = append(scripts, script)
}
}
}
// Process indexes
for indexName, modelIndex := range modelTable.Indexes {
// Skip primary key indexes
@@ -703,7 +753,7 @@ func (w *MigrationWriter) generateAuditScripts(schema *models.Schema, auditConfi
}
// Generate audit function
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
funcSQL, err := w.executor.ExecuteAuditFunction(funcData)

View File

@@ -121,7 +121,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
}
// Verify audit function
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
t.Error("Migration missing audit function")
}
@@ -137,7 +137,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
}
func TestTemplateExecutor_CreateTable(t *testing.T) {
executor, err := NewTemplateExecutor()
executor, err := NewTemplateExecutor(false)
if err != nil {
t.Fatalf("Failed to create executor: %v", err)
}
@@ -170,14 +170,14 @@ func TestTemplateExecutor_CreateTable(t *testing.T) {
}
func TestTemplateExecutor_AuditFunction(t *testing.T) {
executor, err := NewTemplateExecutor()
executor, err := NewTemplateExecutor(false)
if err != nil {
t.Fatalf("Failed to create executor: %v", err)
}
data := AuditFunctionData{
SchemaName: "public",
FunctionName: "ft_audit_users",
FunctionName: "tf_audit_users",
TableName: "users",
TablePrefix: "NULL",
PrimaryKey: "id",
@@ -202,7 +202,7 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
t.Logf("Generated SQL:\n%s", sql)
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
t.Error("SQL missing function definition")
}
if !strings.Contains(sql, "IF TG_OP = 'INSERT'") {
@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
t.Error("SQL missing DELETE handling")
}
}
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
// Current database (empty)
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("entity")
current.Schemas = append(current.Schemas, currentSchema)
// Model database (with constraint starting with number)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("entity")
// Create individual_actor_relationship table
table := models.InitTable("individual_actor_relationship", "entity")
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
actorIDCol.Type = "integer"
table.Columns["actor_id"] = actorIDCol
// Add constraint with name starting with number
constraint := &models.Constraint{
Name: "215162_fk_actor",
Type: models.ForeignKeyConstraint,
Columns: []string{"actor_id"},
ReferencedSchema: "entity",
ReferencedTable: "actor",
ReferencedColumns: []string{"id"},
OnDelete: "CASCADE",
OnUpdate: "NO ACTION",
}
table.Constraints["215162_fk_actor"] = constraint
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
// Generate migration
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
err = writer.WriteMigration(model, current)
if err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
t.Logf("Generated migration:\n%s", output)
// Verify constraint name is properly quoted
if !strings.Contains(output, `"215162_fk_actor"`) {
t.Error("Constraint name starting with number should be quoted")
}
// Verify the SQL is syntactically correct (contains required keywords)
if !strings.Contains(output, "ADD CONSTRAINT") {
t.Error("Migration missing ADD CONSTRAINT")
}
if !strings.Contains(output, "FOREIGN KEY") {
t.Error("Migration missing FOREIGN KEY")
}
}

View File

@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
"quote": quote,
"escape": escape,
"safe_identifier": safeIdentifier,
"quote_ident": quoteIdent,
// Type conversion
"goTypeToSQL": goTypeToSQL,
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
return strings.ToLower(safe)
}
// quoteIdent quotes a PostgreSQL identifier if necessary
// Identifiers need quoting if they:
// - Start with a digit
// - Contain special characters
// - Are reserved keywords
// - Contain uppercase letters (to preserve case)
func quoteIdent(s string) string {
if s == "" {
return `""`
}
// Check if quoting is needed
needsQuoting := unicode.IsDigit(rune(s[0])) // Starts with digit
// Contains uppercase letters or special characters
for _, r := range s {
if unicode.IsUpper(r) {
needsQuoting = true
break
}
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
needsQuoting = true
break
}
}
if needsQuoting {
// Escape double quotes by doubling them
escaped := strings.ReplaceAll(s, `"`, `""`)
return `"` + escaped + `"`
}
return s
}
// Type conversion functions
// goTypeToSQL converts Go type to PostgreSQL type

View File

@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
}
}
func TestQuoteIdent(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"valid_name", "valid_name"},
{"ValidName", `"ValidName"`},
{"123column", `"123column"`},
{"215162_fk_constraint", `"215162_fk_constraint"`},
{"user-id", `"user-id"`},
{"user@domain", `"user@domain"`},
{`"quoted"`, `"""quoted"""`},
{"", `""`},
{"lowercase", "lowercase"},
{"with_underscore", "with_underscore"},
}
for _, tt := range tests {
result := quoteIdent(tt.input)
if result != tt.expected {
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestGoTypeToSQL(t *testing.T) {
tests := []struct {
input string
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
// Check that all expected functions are registered
expectedFuncs := []string{
"upper", "lower", "snake_case", "camelCase",
"indent", "quote", "escape", "safe_identifier",
"indent", "quote", "escape", "safe_identifier", "quote_ident",
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
"first", "last", "filter", "mapFunc", "join_with",
"join",
@@ -289,7 +314,7 @@ func TestFormatType(t *testing.T) {
// Test that template functions work in actual templates
func TestTemplateFunctionsInTemplate(t *testing.T) {
executor, err := NewTemplateExecutor()
executor, err := NewTemplateExecutor(false)
if err != nil {
t.Fatalf("Failed to create executor: %v", err)
}

View File

@@ -18,14 +18,39 @@ type TemplateExecutor struct {
templates *template.Template
}
// NewTemplateExecutor creates a new template executor
func NewTemplateExecutor() (*TemplateExecutor, error) {
// NewTemplateExecutor creates a new template executor.
// flattenSchema controls whether schema.table identifiers use dot or underscore separation.
func NewTemplateExecutor(flattenSchema bool) (*TemplateExecutor, error) {
// Create template with custom functions
funcMap := make(template.FuncMap)
for k, v := range TemplateFunctions() {
funcMap[k] = v
}
// qual_table returns a quoted, schema-qualified identifier.
// With flatten=false: "schema"."table" (or unquoted equivalents).
// With flatten=true: "schema_table".
funcMap["qual_table"] = func(schema, name string) string {
if schema == "" {
return quoteIdent(name)
}
if flattenSchema {
return quoteIdent(schema + "_" + name)
}
return quoteIdent(schema) + "." + quoteIdent(name)
}
// qual_table_raw is the same as qual_table but without identifier quoting.
funcMap["qual_table_raw"] = func(schema, name string) string {
if schema == "" {
return name
}
if flattenSchema {
return schema + "_" + name
}
return schema + "." + name
}
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS, "templates/*.tmpl")
if err != nil {
return nil, fmt.Errorf("failed to parse templates: %w", err)
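A minimal sketch of what the two helpers produce (standalone re-implementation; identifier quoting is omitted since the lowercase snake_case inputs below never need it):

package main

import "fmt"

// qualTable mirrors qual_table_raw above; qual_table additionally runs
// each part through quoteIdent.
func qualTable(schema, name string, flatten bool) string {
	if schema == "" {
		return name
	}
	if flatten {
		return schema + "_" + name
	}
	return schema + "." + name
}

func main() {
	fmt.Println(qualTable("org", "api_event", false)) // org.api_event
	fmt.Println(qualTable("org", "api_event", true))  // org_api_event
	fmt.Println(qualTable("", "api_event", false))    // api_event
}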
@@ -177,6 +202,72 @@ type AuditTriggerData struct {
Events string
}
// CreateUniqueConstraintData contains data for create unique constraint template
type CreateUniqueConstraintData struct {
SchemaName string
TableName string
ConstraintName string
Columns string
}
// CreateCheckConstraintData contains data for create check constraint template
type CreateCheckConstraintData struct {
SchemaName string
TableName string
ConstraintName string
Expression string
}
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
type CreateForeignKeyWithCheckData struct {
SchemaName string
TableName string
ConstraintName string
SourceColumns string
TargetSchema string
TargetTable string
TargetColumns string
OnDelete string
OnUpdate string
Deferrable bool
}
// SetSequenceValueData contains data for set sequence value template
type SetSequenceValueData struct {
SchemaName string
TableName string
SequenceName string
ColumnName string
}
// CreateSequenceData contains data for create sequence template
type CreateSequenceData struct {
SchemaName string
SequenceName string
Increment int
MinValue int64
MaxValue int64
StartValue int64
CacheSize int
}
// AddColumnWithCheckData contains data for add column with existence check template
type AddColumnWithCheckData struct {
SchemaName string
TableName string
ColumnName string
ColumnDefinition string
}
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
type CreatePrimaryKeyWithAutoGenCheckData struct {
SchemaName string
TableName string
ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string
}
// Execute methods for each template
// ExecuteCreateTable executes the create table template
@@ -319,6 +410,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
return buf.String(), nil
}
// ExecuteCreateUniqueConstraint executes the create unique constraint template
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreateCheckConstraint executes the create check constraint template
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
}
return buf.String(), nil
}
// ExecuteSetSequenceValue executes the set sequence value template
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreateSequence executes the create sequence template
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
}
return buf.String(), nil
}
// ExecuteAddColumnWithCheck executes the add column with check template
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
}
return buf.String(), nil
}
// Helper functions to build template data from models
// BuildCreateTableData builds CreateTableData from a models.Table
@@ -355,7 +516,7 @@ func BuildAuditFunctionData(
auditSchema string,
userFunction string,
) AuditFunctionData {
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
// Build list of audited columns
auditedColumns := make([]*models.Column, 0)

View File

@@ -1,4 +1,4 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
ALTER TABLE {{qual_table .SchemaName .TableName}}
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
{{- if .Default}} DEFAULT {{.Default}}{{end}}
{{- if .NotNull}} NOT NULL{{end}};

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND column_name = '{{.ColumnName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD COLUMN {{.ColumnDefinition}};
END IF;
END;
$$;

View File

@@ -1,7 +1,7 @@
{{- if .SetDefault -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
{{- else -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
{{- end -}}

View File

@@ -1,2 +1,2 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
ALTER TABLE {{qual_table .SchemaName .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};

View File

@@ -1,4 +1,4 @@
CREATE OR REPLACE FUNCTION {{.SchemaName}}.{{.FunctionName}}()
CREATE OR REPLACE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}()
RETURNS trigger AS
$body$
DECLARE
@@ -81,4 +81,4 @@ LANGUAGE plpgsql
VOLATILE
SECURITY DEFINER;
COMMENT ON FUNCTION {{.SchemaName}}.{{.FunctionName}}() IS 'Audit trigger function for table {{.SchemaName}}.{{.TableName}}';
COMMENT ON FUNCTION {{qual_table_raw .SchemaName .FunctionName}}() IS 'Audit trigger function for table {{qual_table_raw .SchemaName .TableName}}';

View File

@@ -4,13 +4,13 @@ BEGIN
SELECT 1
FROM pg_trigger
WHERE tgname = '{{.TriggerName}}'
AND tgrelid = '{{.SchemaName}}.{{.TableName}}'::regclass
AND tgrelid = '{{qual_table_raw .SchemaName .TableName}}'::regclass
) THEN
CREATE TRIGGER {{.TriggerName}}
AFTER {{.Events}}
ON {{.SchemaName}}.{{.TableName}}
ON {{qual_table_raw .SchemaName .TableName}}
FOR EACH ROW
EXECUTE FUNCTION {{.SchemaName}}.{{.FunctionName}}();
EXECUTE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}();
END IF;
END;
$$;

View File

@@ -1,6 +1,6 @@
{{/* Base constraint template */}}
{{- define "constraint_base" -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
ADD CONSTRAINT {{.ConstraintName}}
{{block "constraint_definition" .}}{{end}};
{{- end -}}
@@ -15,7 +15,7 @@ BEGIN
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
DROP CONSTRAINT {{.ConstraintName}};
END IF;
END;
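The `constraint_base` / `constraint_definition` pair above relies on Go `text/template` block composition: `block` declares a named slot with a default body, and a later `define` of the same name overrides it. A self-contained sketch of the pattern (illustrative names and data, not the package's actual loader):

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// Base template with an overridable slot, as in constraint_base.
	t := template.Must(template.New("constraint_base").Parse(
		`ALTER TABLE {{.Table}}
    ADD CONSTRAINT {{.Name}}
    {{block "constraint_definition" .}}{{end}};`))

	// A concrete constraint type fills the slot by redefining it.
	template.Must(t.Parse(
		`{{define "constraint_definition"}}UNIQUE ({{.Cols}}){{end}}`))

	data := map[string]string{"Table": "public.users", "Name": "uq_email", "Cols": "email"}
	// Writes an ALTER TABLE ... ADD CONSTRAINT ... UNIQUE (email); statement.
	if err := t.Execute(os.Stdout, data); err != nil {
		panic(err)
	}
}
```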


@@ -11,7 +11,7 @@
{{/* Base ALTER TABLE structure */}}
{{- define "alter_table_base" -}}
-ALTER TABLE {{.SchemaName}}.{{.TableName}}
+ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
{{block "alter_operation" .}}{{end}};
{{- end -}}
@@ -30,5 +30,5 @@ $$;
{{/* Common drop pattern */}}
{{- define "drop_if_exists" -}}
{{block "drop_type" .}}{{end}} IF EXISTS {{.SchemaName}}.{{.ObjectName}};
{{block "drop_type" .}}{{end}} IF EXISTS {{qual_table_raw .SchemaName .ObjectName}};
{{- end -}}


@@ -1 +1 @@
-COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
+COMMENT ON COLUMN {{qual_table .SchemaName .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';


@@ -1 +1 @@
-COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
+COMMENT ON TABLE {{qual_table .SchemaName .TableName}} IS '{{.Comment}}';


@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
END IF;
END;
$$;


@@ -1,10 +1,10 @@
-ALTER TABLE {{.SchemaName}}.{{.TableName}}
-    DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
+ALTER TABLE {{qual_table .SchemaName .TableName}}
+    DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
-ALTER TABLE {{.SchemaName}}.{{.TableName}}
-    ADD CONSTRAINT {{.ConstraintName}}
+ALTER TABLE {{qual_table .SchemaName .TableName}}
+    ADD CONSTRAINT {{quote_ident .ConstraintName}}
    FOREIGN KEY ({{.SourceColumns}})
-    REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
+    REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
ON DELETE {{.OnDelete}}
ON UPDATE {{.OnUpdate}}
DEFERRABLE;


@@ -0,0 +1,18 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}}
FOREIGN KEY ({{.SourceColumns}})
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
ON DELETE {{.OnDelete}}
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
DEFERRABLE{{end}};
END IF;
END;
$$;


@@ -1,2 +1,2 @@
-CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
-    ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
+CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
+    ON {{qual_table .SchemaName .TableName}} USING {{.IndexType}} ({{.Columns}});


@@ -6,8 +6,8 @@ BEGIN
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
-        ALTER TABLE {{.SchemaName}}.{{.TableName}}
-        ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
+        ALTER TABLE {{qual_table .SchemaName .TableName}}
+        ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
$$;


@@ -0,0 +1,27 @@
DO $$
DECLARE
auto_pk_name text;
BEGIN
-- Drop auto-generated primary key if it exists
SELECT constraint_name INTO auto_pk_name
FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_type = 'PRIMARY KEY'
AND constraint_name IN ({{.AutoGenNames}});
IF auto_pk_name IS NOT NULL THEN
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
END IF;
-- Add named primary key if it doesn't exist
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
$$;


@@ -0,0 +1,6 @@
CREATE SEQUENCE IF NOT EXISTS {{qual_table .SchemaName .SequenceName}}
INCREMENT {{.Increment}}
MINVALUE {{.MinValue}}
MAXVALUE {{.MaxValue}}
START {{.StartValue}}
CACHE {{.CacheSize}};


@@ -1,7 +1,7 @@
-CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
+CREATE TABLE IF NOT EXISTS {{qual_table .SchemaName .TableName}} (
{{- range $i, $col := .Columns}}
{{- if $i}},{{end}}
-    {{$col.Name}} {{$col.Type}}
+    {{quote_ident $col.Name}} {{$col.Type}}
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
{{- if $col.NotNull}} NOT NULL{{end}}
{{- end}}


@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
END IF;
END;
$$;


@@ -1 +1 @@
-ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
+ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};


@@ -1 +1 @@
-DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
+DROP INDEX IF EXISTS {{qual_table .SchemaName .IndexName}} CASCADE;


@@ -16,7 +16,7 @@
{{/* Qualified table name */}}
{{- define "qualified_table" -}}
-{{.SchemaName}}.{{.TableName}}
+{{qual_table_raw .SchemaName .TableName}}
{{- end -}}
{{/* Index method clause */}}


@@ -0,0 +1,19 @@
DO $$
DECLARE
m_cnt bigint;
BEGIN
IF EXISTS (
SELECT 1 FROM pg_class c
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relname = '{{.SequenceName}}'
AND n.nspname = '{{.SchemaName}}'
AND c.relkind = 'S'
) THEN
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
FROM {{qual_table .SchemaName .TableName}}
INTO m_cnt;
PERFORM setval('{{qual_table_raw .SchemaName .SequenceName}}'::regclass, m_cnt);
END IF;
END;
$$;

File diff suppressed because it is too large.

@@ -45,11 +45,11 @@ func TestWriteDatabase(t *testing.T) {
// Add unique index
uniqueEmailIndex := &models.Index{
-Name:    "uk_users_email",
+Name:    "uidx_users_email",
Unique: true,
Columns: []string{"email"},
}
table.Indexes["uk_users_email"] = uniqueEmailIndex
table.Indexes["uidx_users_email"] = uniqueEmailIndex
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
@@ -164,6 +164,296 @@ func TestWriteForeignKeys(t *testing.T) {
}
}
func TestWriteUniqueConstraints(t *testing.T) {
// Create a test database with unique constraints
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create table with unique constraints
table := models.InitTable("users", "public")
// Add columns
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar(255)"
emailCol.NotNull = true
table.Columns["email"] = emailCol
guidCol := models.InitColumn("guid", "users", "public")
guidCol.Type = "uuid"
guidCol.NotNull = true
table.Columns["guid"] = guidCol
// Add unique constraints
emailConstraint := &models.Constraint{
Name: "uq_email",
Type: models.UniqueConstraint,
Schema: "public",
Table: "users",
Columns: []string{"email"},
}
table.Constraints["uq_email"] = emailConstraint
guidConstraint := &models.Constraint{
Name: "uq_guid",
Type: models.UniqueConstraint,
Schema: "public",
Table: "users",
Columns: []string{"guid"},
}
table.Constraints["uq_guid"] = guidConstraint
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify unique constraints are present
if !strings.Contains(output, "-- Unique constraints for schema: public") {
t.Errorf("Output missing unique constraints header")
}
if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
}
if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
}
}
func TestWriteCheckConstraints(t *testing.T) {
// Create a test database with check constraints
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create table with check constraints
table := models.InitTable("products", "public")
// Add columns
priceCol := models.InitColumn("price", "products", "public")
priceCol.Type = "numeric(10,2)"
table.Columns["price"] = priceCol
statusCol := models.InitColumn("status", "products", "public")
statusCol.Type = "varchar(20)"
table.Columns["status"] = statusCol
quantityCol := models.InitColumn("quantity", "products", "public")
quantityCol.Type = "integer"
table.Columns["quantity"] = quantityCol
// Add check constraints
priceConstraint := &models.Constraint{
Name: "ck_price_positive",
Type: models.CheckConstraint,
Schema: "public",
Table: "products",
Expression: "price >= 0",
}
table.Constraints["ck_price_positive"] = priceConstraint
statusConstraint := &models.Constraint{
Name: "ck_status_valid",
Type: models.CheckConstraint,
Schema: "public",
Table: "products",
Expression: "status IN ('active', 'inactive', 'discontinued')",
}
table.Constraints["ck_status_valid"] = statusConstraint
quantityConstraint := &models.Constraint{
Name: "ck_quantity_nonnegative",
Type: models.CheckConstraint,
Schema: "public",
Table: "products",
Expression: "quantity >= 0",
}
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify check constraints are present
if !strings.Contains(output, "-- Check constraints for schema: public") {
t.Errorf("Output missing check constraints header")
}
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
}
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
}
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
}
}
func TestWriteAllConstraintTypes(t *testing.T) {
// Create a comprehensive test with all constraint types
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create orders table
ordersTable := models.InitTable("orders", "public")
// Add columns
idCol := models.InitColumn("id", "orders", "public")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
ordersTable.Columns["id"] = idCol
userIDCol := models.InitColumn("user_id", "orders", "public")
userIDCol.Type = "integer"
userIDCol.NotNull = true
ordersTable.Columns["user_id"] = userIDCol
orderNumberCol := models.InitColumn("order_number", "orders", "public")
orderNumberCol.Type = "varchar(50)"
orderNumberCol.NotNull = true
ordersTable.Columns["order_number"] = orderNumberCol
totalCol := models.InitColumn("total", "orders", "public")
totalCol.Type = "numeric(10,2)"
ordersTable.Columns["total"] = totalCol
statusCol := models.InitColumn("status", "orders", "public")
statusCol.Type = "varchar(20)"
ordersTable.Columns["status"] = statusCol
// Add primary key constraint
pkConstraint := &models.Constraint{
Name: "pk_orders",
Type: models.PrimaryKeyConstraint,
Schema: "public",
Table: "orders",
Columns: []string{"id"},
}
ordersTable.Constraints["pk_orders"] = pkConstraint
// Add unique constraint
uniqueConstraint := &models.Constraint{
Name: "uq_order_number",
Type: models.UniqueConstraint,
Schema: "public",
Table: "orders",
Columns: []string{"order_number"},
}
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
// Add check constraint
checkConstraint := &models.Constraint{
Name: "ck_total_positive",
Type: models.CheckConstraint,
Schema: "public",
Table: "orders",
Expression: "total > 0",
}
ordersTable.Constraints["ck_total_positive"] = checkConstraint
statusCheckConstraint := &models.Constraint{
Name: "ck_status_valid",
Type: models.CheckConstraint,
Schema: "public",
Table: "orders",
Expression: "status IN ('pending', 'completed', 'cancelled')",
}
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
// Add foreign key constraint (referencing a users table)
fkConstraint := &models.Constraint{
Name: "fk_orders_user",
Type: models.ForeignKeyConstraint,
Schema: "public",
Table: "orders",
Columns: []string{"user_id"},
ReferencedSchema: "public",
ReferencedTable: "users",
ReferencedColumns: []string{"id"},
OnDelete: "CASCADE",
OnUpdate: "CASCADE",
}
ordersTable.Constraints["fk_orders_user"] = fkConstraint
schema.Tables = append(schema.Tables, ordersTable)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify all constraint types are present
expectedConstraints := map[string]string{
"Primary Key": "PRIMARY KEY",
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
"Foreign Key": "FOREIGN KEY",
}
for name, expected := range expectedConstraints {
if !strings.Contains(output, expected) {
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
}
}
// Verify section headers
sections := []string{
"-- Primary keys for schema: public",
"-- Unique constraints for schema: public",
"-- Check constraints for schema: public",
"-- Foreign keys for schema: public",
}
for _, section := range sections {
if !strings.Contains(output, section) {
t.Errorf("Output missing section header: %s", section)
}
}
}
func TestWriteTable(t *testing.T) {
// Create a single table
table := models.InitTable("products", "public")
@@ -241,3 +531,327 @@ func TestIsIntegerType(t *testing.T) {
}
}
}
func TestTypeConversion(t *testing.T) {
// Test that invalid Go types are converted to valid PostgreSQL types
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a test table with Go types instead of SQL types
table := models.InitTable("test_types", "public")
// Add columns with Go types (invalid for PostgreSQL)
stringCol := models.InitColumn("name", "test_types", "public")
stringCol.Type = "string" // Should be converted to "text"
table.Columns["name"] = stringCol
int64Col := models.InitColumn("big_id", "test_types", "public")
int64Col.Type = "int64" // Should be converted to "bigint"
table.Columns["big_id"] = int64Col
int16Col := models.InitColumn("small_id", "test_types", "public")
int16Col.Type = "int16" // Should be converted to "smallint"
table.Columns["small_id"] = int16Col
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify that Go types were converted to PostgreSQL types
if strings.Contains(output, "string") {
t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output)
}
if strings.Contains(output, "int64") {
t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output)
}
if strings.Contains(output, "int16") {
t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output)
}
// Verify correct PostgreSQL types are present
if !strings.Contains(output, "text") {
t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output)
}
if !strings.Contains(output, "bigint") {
t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output)
}
if !strings.Contains(output, "smallint") {
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
}
}
func TestPrimaryKeyExistenceCheck(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
nameCol := models.InitColumn("name", "products", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
t.Logf("Generated SQL:\n%s", output)
// Verify our naming convention is used
if !strings.Contains(output, "pk_public_products") {
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
}
// Verify it drops auto-generated primary keys
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
}
// Verify it checks for our specific named constraint before adding it
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
}
}
func TestColumnSizeSpecifiers(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("test_sizes", "public")
// Integer with invalid size specifier - should ignore size
integerCol := models.InitColumn("int_col", "test_sizes", "public")
integerCol.Type = "integer"
integerCol.Length = 32
table.Columns["int_col"] = integerCol
// Bigint with invalid size specifier - should ignore size
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
bigintCol.Type = "bigint"
bigintCol.Length = 64
table.Columns["bigint_col"] = bigintCol
// Smallint with invalid size specifier - should ignore size
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
smallintCol.Type = "smallint"
smallintCol.Length = 16
table.Columns["smallint_col"] = smallintCol
// Text with length - should convert to varchar
textCol := models.InitColumn("text_col", "test_sizes", "public")
textCol.Type = "text"
textCol.Length = 100
table.Columns["text_col"] = textCol
// Varchar with length - should keep varchar with length
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
varcharCol.Type = "varchar"
varcharCol.Length = 50
table.Columns["varchar_col"] = varcharCol
// Decimal with precision and scale - should keep them
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
decimalCol.Type = "decimal"
decimalCol.Precision = 19
decimalCol.Scale = 4
table.Columns["decimal_col"] = decimalCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
t.Logf("Generated SQL:\n%s", output)
// Verify invalid size specifiers are NOT present
invalidPatterns := []string{
"integer(32)",
"bigint(64)",
"smallint(16)",
"text(100)",
}
for _, pattern := range invalidPatterns {
if strings.Contains(output, pattern) {
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
}
}
// Verify valid patterns ARE present
validPatterns := []string{
"integer", // without size
"bigint", // without size
"smallint", // without size
"varchar(100)", // text converted to varchar with length
"varchar(50)", // varchar with length
"decimal(19,4)", // decimal with precision and scale
}
for _, pattern := range validPatterns {
if !strings.Contains(output, pattern) {
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
}
}
}
func TestGenerateAddColumnStatements(t *testing.T) {
// Create a test database with tables that have new columns
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a table with columns
table := models.InitTable("users", "public")
// Existing column
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
idCol.Sequence = 1
table.Columns["id"] = idCol
// New column to be added
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar"
emailCol.Length = 255
emailCol.NotNull = true
emailCol.Sequence = 2
table.Columns["email"] = emailCol
// New column with default
statusCol := models.InitColumn("status", "users", "public")
statusCol.Type = "text"
statusCol.Default = "active"
statusCol.Sequence = 3
table.Columns["status"] = statusCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer
options := &writers.WriterOptions{}
writer := NewWriter(options)
// Generate ADD COLUMN statements
statements, err := writer.GenerateAddColumnsForDatabase(db)
if err != nil {
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
}
// Join all statements to verify content
output := strings.Join(statements, "\n")
t.Logf("Generated ADD COLUMN statements:\n%s", output)
// Verify expected elements
expectedStrings := []string{
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
"information_schema.columns",
"table_schema = 'public'",
"table_name = 'users'",
"column_name = 'id'",
"column_name = 'email'",
"column_name = 'status'",
}
for _, expected := range expectedStrings {
if !strings.Contains(output, expected) {
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
}
}
// Verify DO blocks are present for conditional adds
doBlockCount := strings.Count(output, "DO $$")
if doBlockCount < 3 {
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
}
// Verify IF NOT EXISTS logic
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
if ifNotExistsCount < 3 {
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
}
}
func TestWriteAddColumnStatements(t *testing.T) {
// Create a test database
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a table with a new column to be added
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "integer"
table.Columns["id"] = idCol
// New column with various properties
descCol := models.InitColumn("description", "products", "public")
descCol.Type = "text"
descCol.NotNull = false
table.Columns["description"] = descCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write ADD COLUMN statements
err := writer.WriteAddColumnStatements(db)
if err != nil {
t.Fatalf("WriteAddColumnStatements failed: %v", err)
}
output := buf.String()
t.Logf("Generated output:\n%s", output)
// Verify output contains expected elements
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
}
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
}
if !strings.Contains(output, "DO $$") {
t.Errorf("Output missing DO block\nFull output:\n%s", output)
}
}


@@ -4,7 +4,7 @@ The SQL Executor Writer (`sqlexec`) executes SQL scripts from `models.Script` ob
## Features
-- **Ordered Execution**: Scripts execute in Priority→Sequence order
+- **Ordered Execution**: Scripts execute in Priority→Sequence→Name order
- **PostgreSQL Support**: Uses `pgx/v5` driver for robust PostgreSQL connectivity
- **Stop on Error**: Execution halts immediately on first error (default behavior)
- **Progress Reporting**: Prints execution status to stdout
@@ -103,19 +103,40 @@ Scripts are sorted and executed based on:
1. **Priority** (ascending): Lower priority values execute first
2. **Sequence** (ascending): Within same priority, lower sequence values execute first
+3. **Name** (ascending): Within same priority and sequence, alphabetical order by name
### Example Execution Order
Given these scripts:
```
-Script A: Priority=2, Sequence=1
-Script B: Priority=1, Sequence=3
-Script C: Priority=1, Sequence=1
-Script D: Priority=1, Sequence=2
-Script E: Priority=3, Sequence=1
+Script A: Priority=2, Sequence=1, Name="zebra"
+Script B: Priority=1, Sequence=3, Name="script"
+Script C: Priority=1, Sequence=1, Name="apple"
+Script D: Priority=1, Sequence=1, Name="beta"
+Script E: Priority=3, Sequence=1, Name="script"
```
-Execution order: **C → D → B → A → E**
+Execution order: **C (apple) → D (beta) → B → A → E**
+### Directory-based Sorting Example
+Given these files:
+```
+1_001_create_schema.sql
+1_001_create_users.sql   ← Alphabetically before "drop_tables"
+1_001_drop_tables.sql
+1_002_add_indexes.sql
+2_001_constraints.sql
+```
+Execution order (note alphabetical sorting at same priority/sequence):
+```
+1_001_create_schema.sql
+1_001_create_users.sql
+1_001_drop_tables.sql
+1_002_add_indexes.sql
+2_001_constraints.sql
+```
## Output
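In code, the three-level ordering documented above maps onto a single `sort.Slice` comparator; a minimal sketch, mirroring the implementation shown in the writer diff below:

```go
sort.Slice(scripts, func(i, j int) bool {
	if scripts[i].Priority != scripts[j].Priority {
		return scripts[i].Priority < scripts[j].Priority // 1. priority ascending
	}
	if scripts[i].Sequence != scripts[j].Sequence {
		return scripts[i].Sequence < scripts[j].Sequence // 2. sequence ascending
	}
	return scripts[i].Name < scripts[j].Name // 3. name ascending
})
```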


@@ -23,6 +23,11 @@ func NewWriter(options *writers.WriterOptions) *Writer {
}
}
+// Options returns the writer options (useful for reading execution results)
+func (w *Writer) Options() *writers.WriterOptions {
+return w.options
+}
// WriteDatabase executes all scripts from all schemas in the database
func (w *Writer) WriteDatabase(db *models.Database) error {
if db == nil {
@@ -86,20 +91,39 @@ func (w *Writer) WriteTable(table *models.Table) error {
return fmt.Errorf("WriteTable is not supported for SQL script execution")
}
-// executeScripts executes scripts in Priority then Sequence order
+// executeScripts executes scripts in Priority, Sequence, then Name order
func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*models.Script) error {
if len(scripts) == 0 {
return nil
}
-// Sort scripts by Priority (ascending) then Sequence (ascending)
+// Check if we should ignore errors
+ignoreErrors := false
+if val, ok := w.options.Metadata["ignore_errors"].(bool); ok {
+ignoreErrors = val
+}
+// Track failed scripts and execution counts
+var failedScripts []struct {
+name string
+priority int
+sequence uint
+err error
+}
+successCount := 0
+totalCount := 0
+// Sort scripts by Priority (ascending), Sequence (ascending), then Name (ascending)
sortedScripts := make([]*models.Script, len(scripts))
copy(sortedScripts, scripts)
sort.Slice(sortedScripts, func(i, j int) bool {
if sortedScripts[i].Priority != sortedScripts[j].Priority {
return sortedScripts[i].Priority < sortedScripts[j].Priority
}
-return sortedScripts[i].Sequence < sortedScripts[j].Sequence
+if sortedScripts[i].Sequence != sortedScripts[j].Sequence {
+return sortedScripts[i].Sequence < sortedScripts[j].Sequence
+}
+return sortedScripts[i].Name < sortedScripts[j].Name
})
// Execute each script in order
@@ -108,18 +132,49 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
continue
}
+totalCount++
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
script.Name, script.Priority, script.Sequence)
// Execute the SQL script
_, err := conn.Exec(ctx, script.SQL)
if err != nil {
return fmt.Errorf("failed to execute script %s (Priority=%d, Sequence=%d): %w",
if ignoreErrors {
fmt.Printf("⚠ Error executing %s: %v (continuing due to --ignore-errors)\n", script.Name, err)
failedScripts = append(failedScripts, struct {
name string
priority int
sequence uint
err error
}{
name: script.Name,
priority: script.Priority,
sequence: script.Sequence,
err: err,
})
continue
}
return fmt.Errorf("script %s (Priority=%d, Sequence=%d): %w",
script.Name, script.Priority, script.Sequence, err)
}
+successCount++
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
}
+// Store execution results in metadata for caller
+w.options.Metadata["execution_total"] = totalCount
+w.options.Metadata["execution_success"] = successCount
+w.options.Metadata["execution_failed"] = len(failedScripts)
+// Print summary of failed scripts if any
+if len(failedScripts) > 0 {
+fmt.Printf("\n⚠ Failed Scripts Summary (%d failed):\n", len(failedScripts))
+for i, failed := range failedScripts {
+fmt.Printf("  %d. %s (Priority=%d, Sequence=%d)\n     Error: %v\n",
+i+1, failed.name, failed.priority, failed.sequence, failed.err)
+}
+}
return nil
}
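Taken together with the `Options()` accessor added above, a caller can opt into error-tolerant execution and read the counters back afterwards. A sketch assuming the metadata keys set in `executeScripts` (`execution_total`, `execution_success`, `execution_failed`); here `db` stands for an existing `*models.Database` and import paths are elided:

```go
opts := &writers.WriterOptions{
	Metadata: map[string]interface{}{
		"ignore_errors": true, // keep going past failing scripts
	},
}
w := NewWriter(opts)
if err := w.WriteDatabase(db); err != nil {
	// With ignore_errors set, only connection-level failures land here;
	// per-script failures are summarized and counted instead.
	log.Fatalf("execution aborted: %v", err)
}
m := w.Options().Metadata
fmt.Printf("%v scripts executed: %v succeeded, %v failed\n",
	m["execution_total"], m["execution_success"], m["execution_failed"])
```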


@@ -99,13 +99,13 @@ func TestWriter_WriteTable(t *testing.T) {
}
}
-// TestScriptSorting verifies that scripts are sorted correctly by Priority then Sequence
+// TestScriptSorting verifies that scripts are sorted correctly by Priority, Sequence, then Name
func TestScriptSorting(t *testing.T) {
scripts := []*models.Script{
{Name: "script1", Priority: 2, Sequence: 1, SQL: "SELECT 1;"},
{Name: "z_script1", Priority: 2, Sequence: 1, SQL: "SELECT 1;"},
{Name: "script2", Priority: 1, Sequence: 3, SQL: "SELECT 2;"},
{Name: "script3", Priority: 1, Sequence: 1, SQL: "SELECT 3;"},
{Name: "script4", Priority: 1, Sequence: 2, SQL: "SELECT 4;"},
{Name: "a_script3", Priority: 1, Sequence: 1, SQL: "SELECT 3;"},
{Name: "b_script4", Priority: 1, Sequence: 1, SQL: "SELECT 4;"},
{Name: "script5", Priority: 3, Sequence: 1, SQL: "SELECT 5;"},
{Name: "script6", Priority: 2, Sequence: 2, SQL: "SELECT 6;"},
}
@@ -114,25 +114,35 @@ func TestScriptSorting(t *testing.T) {
sortedScripts := make([]*models.Script, len(scripts))
copy(sortedScripts, scripts)
-// Use the same sorting logic from executeScripts
+// Sort by Priority, Sequence, then Name (matching executeScripts logic)
for i := 0; i < len(sortedScripts)-1; i++ {
for j := i + 1; j < len(sortedScripts); j++ {
-if sortedScripts[i].Priority > sortedScripts[j].Priority ||
-(sortedScripts[i].Priority == sortedScripts[j].Priority &&
-sortedScripts[i].Sequence > sortedScripts[j].Sequence) {
+si, sj := sortedScripts[i], sortedScripts[j]
+// Compare by priority first
+if si.Priority > sj.Priority {
+sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
+} else if si.Priority == sj.Priority {
+// If same priority, compare by sequence
+if si.Sequence > sj.Sequence {
+sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
+} else if si.Sequence == sj.Sequence {
+// If same sequence, compare by name
+if si.Name > sj.Name {
+sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
+}
+}
+}
}
}
-// Expected order after sorting
+// Expected order after sorting (Priority -> Sequence -> Name)
expectedOrder := []string{
"script3", // Priority 1, Sequence 1
"script4", // Priority 1, Sequence 2
"script2", // Priority 1, Sequence 3
"script1", // Priority 2, Sequence 1
"script6", // Priority 2, Sequence 2
"script5", // Priority 3, Sequence 1
"a_script3", // Priority 1, Sequence 1, Name a_script3
"b_script4", // Priority 1, Sequence 1, Name b_script4
"script2", // Priority 1, Sequence 3
"z_script1", // Priority 2, Sequence 1
"script6", // Priority 2, Sequence 2
"script5", // Priority 3, Sequence 1
}
for i, expected := range expectedOrder {
@@ -153,6 +163,13 @@ func TestScriptSorting(t *testing.T) {
t.Errorf("Sequence not ascending at position %d with same priority %d: %d > %d",
i, sortedScripts[i].Priority, sortedScripts[i].Sequence, sortedScripts[i+1].Sequence)
}
+// Within same priority and sequence, names should be ascending
+if sortedScripts[i].Priority == sortedScripts[i+1].Priority &&
+sortedScripts[i].Sequence == sortedScripts[i+1].Sequence &&
+sortedScripts[i].Name > sortedScripts[i+1].Name {
+t.Errorf("Name not ascending at position %d with same priority/sequence: %s > %s",
+i, sortedScripts[i].Name, sortedScripts[i+1].Name)
+}
}
}


@@ -28,10 +28,29 @@ type WriterOptions struct {
// PackageName is the Go package name (for code generation)
PackageName string
+// FlattenSchema disables schema.table dot notation and instead joins
+// schema and table with an underscore (e.g., "public_users").
+// Useful for databases like SQLite that do not support schemas.
+FlattenSchema bool
// Additional options can be added here as needed
Metadata map[string]interface{}
}
+// QualifiedTableName returns a schema-qualified table name.
+// When flatten is true, schema and table are joined with an underscore (e.g., "schema_table").
+// When flatten is false, they are dot-separated (e.g., "schema.table").
+// If schema is empty, just the table name is returned regardless of flatten.
+func QualifiedTableName(schema, table string, flatten bool) string {
+if schema == "" {
+return table
+}
+if flatten {
+return schema + "_" + table
+}
+return schema + "." + table
+}
// SanitizeFilename removes quotes, comments, and invalid characters from identifiers
// to make them safe for use in filenames. This handles:
// - Double and single quotes: "table_name" or 'table_name' -> table_name

vendor/modules.txt

@@ -1,6 +1,92 @@
# 4d63.com/gocheckcompilerdirectives v1.3.0
## explicit; go 1.22.0
# 4d63.com/gochecknoglobals v0.2.2
## explicit; go 1.18
# github.com/4meepo/tagalign v1.4.2
## explicit; go 1.22.0
# github.com/Abirdcfly/dupword v0.1.3
## explicit; go 1.22.0
# github.com/Antonboom/errname v1.0.0
## explicit; go 1.22.1
# github.com/Antonboom/nilnil v1.0.1
## explicit; go 1.22.0
# github.com/Antonboom/testifylint v1.5.2
## explicit; go 1.22.1
# github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c
## explicit; go 1.18
# github.com/Crocmagnon/fatcontext v0.7.1
## explicit; go 1.22.0
# github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24
## explicit; go 1.13
# github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1
## explicit; go 1.23.0
# github.com/Masterminds/semver/v3 v3.3.0
## explicit; go 1.21
# github.com/OpenPeeDeeP/depguard/v2 v2.2.1
## explicit; go 1.23.0
# github.com/alecthomas/go-check-sumtype v0.3.1
## explicit; go 1.22.0
# github.com/alexkohler/nakedret/v2 v2.0.5
## explicit; go 1.21
# github.com/alexkohler/prealloc v1.0.0
## explicit; go 1.15
# github.com/alingse/asasalint v0.0.11
## explicit; go 1.18
# github.com/alingse/nilnesserr v0.1.2
## explicit; go 1.22.0
# github.com/ashanbrown/forbidigo v1.6.0
## explicit; go 1.13
# github.com/ashanbrown/makezero v1.2.0
## explicit; go 1.12
# github.com/beorn7/perks v1.0.1
## explicit; go 1.11
# github.com/bkielbasa/cyclop v1.2.3
## explicit; go 1.22.0
# github.com/blizzy78/varnamelen v0.8.0
## explicit; go 1.16
# github.com/bombsimon/wsl/v4 v4.5.0
## explicit; go 1.22
# github.com/breml/bidichk v0.3.2
## explicit; go 1.22.0
# github.com/breml/errchkjson v0.4.0
## explicit; go 1.22.0
# github.com/butuzov/ireturn v0.3.1
## explicit; go 1.18
# github.com/butuzov/mirror v1.3.0
## explicit; go 1.19
# github.com/catenacyber/perfsprint v0.8.2
## explicit; go 1.22.0
# github.com/ccojocar/zxcvbn-go v1.0.2
## explicit; go 1.20
# github.com/cespare/xxhash/v2 v2.3.0
## explicit; go 1.11
# github.com/charithe/durationcheck v0.0.10
## explicit; go 1.14
# github.com/chavacava/garif v0.1.0
## explicit; go 1.16
# github.com/ckaznocha/intrange v0.3.0
## explicit; go 1.22
# github.com/curioswitch/go-reassign v0.3.0
## explicit; go 1.21
# github.com/daixiang0/gci v0.13.5
## explicit; go 1.21
# github.com/davecgh/go-spew v1.1.1
## explicit
github.com/davecgh/go-spew/spew
# github.com/denis-tingaikin/go-header v0.5.0
## explicit; go 1.21
# github.com/ettle/strcase v0.2.0
## explicit; go 1.12
# github.com/fatih/color v1.18.0
## explicit; go 1.17
# github.com/fatih/structtag v1.2.0
## explicit; go 1.12
# github.com/firefart/nonamedreturns v1.0.5
## explicit; go 1.18
# github.com/fsnotify/fsnotify v1.5.4
## explicit; go 1.16
# github.com/fzipp/gocyclo v0.6.0
## explicit; go 1.18
# github.com/gdamore/encoding v1.0.1
## explicit; go 1.9
github.com/gdamore/encoding
@@ -44,9 +130,75 @@ github.com/gdamore/tcell/v2/terminfo/x/xfce
github.com/gdamore/tcell/v2/terminfo/x/xterm
github.com/gdamore/tcell/v2/terminfo/x/xterm_ghostty
github.com/gdamore/tcell/v2/terminfo/x/xterm_kitty
# github.com/ghostiam/protogetter v0.3.9
## explicit; go 1.22.0
# github.com/go-critic/go-critic v0.12.0
## explicit; go 1.22.0
# github.com/go-toolsmith/astcast v1.1.0
## explicit; go 1.16
# github.com/go-toolsmith/astcopy v1.1.0
## explicit; go 1.16
# github.com/go-toolsmith/astequal v1.2.0
## explicit; go 1.18
# github.com/go-toolsmith/astfmt v1.1.0
## explicit; go 1.16
# github.com/go-toolsmith/astp v1.1.0
## explicit; go 1.16
# github.com/go-toolsmith/strparse v1.1.0
## explicit; go 1.16
# github.com/go-toolsmith/typep v1.1.0
## explicit; go 1.16
# github.com/go-viper/mapstructure/v2 v2.2.1
## explicit; go 1.18
# github.com/go-xmlfmt/xmlfmt v1.1.3
## explicit
# github.com/gobwas/glob v0.2.3
## explicit
# github.com/gofrs/flock v0.12.1
## explicit; go 1.21.0
# github.com/golang/protobuf v1.5.3
## explicit; go 1.9
# github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32
## explicit; go 1.22.0
# github.com/golangci/go-printf-func-name v0.1.0
## explicit; go 1.22.0
# github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d
## explicit; go 1.22.0
# github.com/golangci/golangci-lint v1.64.8
## explicit; go 1.23.0
# github.com/golangci/misspell v0.6.0
## explicit; go 1.21
# github.com/golangci/plugin-module-register v0.1.1
## explicit; go 1.21
# github.com/golangci/revgrep v0.8.0
## explicit; go 1.21
# github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed
## explicit; go 1.20
# github.com/google/go-cmp v0.7.0
## explicit; go 1.21
# github.com/google/uuid v1.6.0
## explicit
github.com/google/uuid
# github.com/gordonklaus/ineffassign v0.1.0
## explicit; go 1.14
# github.com/gostaticanalysis/analysisutil v0.7.1
## explicit; go 1.16
# github.com/gostaticanalysis/comment v1.5.0
## explicit; go 1.22.9
# github.com/gostaticanalysis/forcetypeassert v0.2.0
## explicit; go 1.23.0
# github.com/gostaticanalysis/nilerr v0.1.1
## explicit; go 1.15
# github.com/hashicorp/go-immutable-radix/v2 v2.1.0
## explicit; go 1.18
# github.com/hashicorp/go-version v1.7.0
## explicit
# github.com/hashicorp/golang-lru/v2 v2.0.7
## explicit; go 1.18
# github.com/hashicorp/hcl v1.0.0
## explicit
# github.com/hexops/gotextdiff v1.0.3
## explicit; go 1.16
# github.com/inconshreveable/mousetrap v1.1.0
## explicit; go 1.18
github.com/inconshreveable/mousetrap
@@ -68,23 +220,115 @@ github.com/jackc/pgx/v5/pgconn/ctxwatch
github.com/jackc/pgx/v5/pgconn/internal/bgreader
github.com/jackc/pgx/v5/pgproto3
github.com/jackc/pgx/v5/pgtype
# github.com/jgautheron/goconst v1.7.1
## explicit; go 1.13
# github.com/jingyugao/rowserrcheck v1.1.1
## explicit; go 1.13
# github.com/jinzhu/inflection v1.0.0
## explicit
github.com/jinzhu/inflection
# github.com/jjti/go-spancheck v0.6.4
## explicit; go 1.22.1
# github.com/julz/importas v0.2.0
## explicit; go 1.20
# github.com/karamaru-alpha/copyloopvar v1.2.1
## explicit; go 1.21
# github.com/kisielk/errcheck v1.9.0
## explicit; go 1.22.0
# github.com/kkHAIKE/contextcheck v1.1.6
## explicit; go 1.23.0
# github.com/kr/pretty v0.3.1
## explicit; go 1.12
# github.com/kulti/thelper v0.6.3
## explicit; go 1.18
# github.com/kunwardeep/paralleltest v1.0.10
## explicit; go 1.17
# github.com/lasiar/canonicalheader v1.1.2
## explicit; go 1.22.0
# github.com/ldez/exptostd v0.4.2
## explicit; go 1.22.0
# github.com/ldez/gomoddirectives v0.6.1
## explicit; go 1.22.0
# github.com/ldez/grignotin v0.9.0
## explicit; go 1.22.0
# github.com/ldez/tagliatelle v0.7.1
## explicit; go 1.22.0
# github.com/ldez/usetesting v0.4.2
## explicit; go 1.22.0
# github.com/leonklingele/grouper v1.1.2
## explicit; go 1.18
# github.com/lucasb-eyer/go-colorful v1.2.0
## explicit; go 1.12
github.com/lucasb-eyer/go-colorful
# github.com/macabu/inamedparam v0.1.3
## explicit; go 1.20
# github.com/magiconair/properties v1.8.6
## explicit; go 1.13
# github.com/maratori/testableexamples v1.0.0
## explicit; go 1.19
# github.com/maratori/testpackage v1.1.1
## explicit; go 1.20
# github.com/matoous/godox v1.1.0
## explicit; go 1.18
# github.com/mattn/go-colorable v0.1.14
## explicit; go 1.18
# github.com/mattn/go-isatty v0.0.20
## explicit; go 1.15
# github.com/mattn/go-runewidth v0.0.16
## explicit; go 1.9
github.com/mattn/go-runewidth
# github.com/matttproud/golang_protobuf_extensions v1.0.1
## explicit
# github.com/mgechev/revive v1.7.0
## explicit; go 1.22.1
# github.com/mitchellh/go-homedir v1.1.0
## explicit
# github.com/mitchellh/mapstructure v1.5.0
## explicit; go 1.14
# github.com/moricho/tparallel v0.3.2
## explicit; go 1.20
# github.com/nakabonne/nestif v0.3.1
## explicit; go 1.15
# github.com/nishanths/exhaustive v0.12.0
## explicit; go 1.18
# github.com/nishanths/predeclared v0.2.2
## explicit; go 1.14
# github.com/nunnatsa/ginkgolinter v0.19.1
## explicit; go 1.23.0
# github.com/olekukonko/tablewriter v0.0.5
## explicit; go 1.12
# github.com/pelletier/go-toml v1.9.5
## explicit; go 1.12
# github.com/pelletier/go-toml/v2 v2.2.3
## explicit; go 1.21.0
# github.com/pmezard/go-difflib v1.0.0
## explicit
github.com/pmezard/go-difflib/difflib
# github.com/polyfloyd/go-errorlint v1.7.1
## explicit; go 1.22.0
# github.com/prometheus/client_golang v1.12.1
## explicit; go 1.13
# github.com/prometheus/client_model v0.2.0
## explicit; go 1.9
# github.com/prometheus/common v0.32.1
## explicit; go 1.13
# github.com/prometheus/procfs v0.7.3
## explicit; go 1.13
# github.com/puzpuzpuz/xsync/v3 v3.5.1
## explicit; go 1.18
github.com/puzpuzpuz/xsync/v3
# github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1
## explicit; go 1.19
# github.com/quasilyte/go-ruleguard/dsl v0.3.22
## explicit; go 1.15
# github.com/quasilyte/gogrep v0.5.0
## explicit; go 1.16
# github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727
## explicit; go 1.14
# github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567
## explicit; go 1.17
# github.com/raeperd/recvcheck v0.2.0
## explicit; go 1.22.0
# github.com/rivo/tview v0.42.0
## explicit; go 1.18
github.com/rivo/tview
@@ -93,20 +337,76 @@ github.com/rivo/tview
github.com/rivo/uniseg
# github.com/rogpeppe/go-internal v1.14.1
## explicit; go 1.23
# github.com/ryancurrah/gomodguard v1.3.5
## explicit; go 1.22.0
# github.com/ryanrolds/sqlclosecheck v0.5.1
## explicit; go 1.20
# github.com/sanposhiho/wastedassign/v2 v2.1.0
## explicit; go 1.18
# github.com/santhosh-tekuri/jsonschema/v6 v6.0.1
## explicit; go 1.21
# github.com/sashamelentyev/interfacebloat v1.1.0
## explicit; go 1.18
# github.com/sashamelentyev/usestdlibvars v1.28.0
## explicit; go 1.20
# github.com/securego/gosec/v2 v2.22.2
## explicit; go 1.23.0
# github.com/sirupsen/logrus v1.9.3
## explicit; go 1.13
# github.com/sivchari/containedctx v1.0.3
## explicit; go 1.17
# github.com/sivchari/tenv v1.12.1
## explicit; go 1.22.0
# github.com/sonatard/noctx v0.1.0
## explicit; go 1.22.0
# github.com/sourcegraph/go-diff v0.7.0
## explicit; go 1.14
# github.com/spf13/afero v1.12.0
## explicit; go 1.21
# github.com/spf13/cast v1.5.0
## explicit; go 1.18
# github.com/spf13/cobra v1.10.2
## explicit; go 1.15
github.com/spf13/cobra
# github.com/spf13/jwalterweatherman v1.1.0
## explicit
# github.com/spf13/pflag v1.0.10
## explicit; go 1.12
github.com/spf13/pflag
# github.com/spf13/viper v1.12.0
## explicit; go 1.17
# github.com/ssgreg/nlreturn/v2 v2.2.1
## explicit; go 1.13
# github.com/stbenjam/no-sprintf-host-port v0.2.0
## explicit; go 1.18
# github.com/stretchr/objx v0.5.2
## explicit; go 1.20
# github.com/stretchr/testify v1.11.1
## explicit; go 1.17
github.com/stretchr/testify/assert
github.com/stretchr/testify/assert/yaml
github.com/stretchr/testify/require
# github.com/subosito/gotenv v1.4.1
## explicit; go 1.18
# github.com/tdakkota/asciicheck v0.4.1
## explicit; go 1.22.0
# github.com/tetafro/godot v1.5.0
## explicit; go 1.20
# github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3
## explicit; go 1.12
# github.com/timonwong/loggercheck v0.10.1
## explicit; go 1.22.0
# github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
## explicit
github.com/tmthrgd/go-hex
# github.com/tomarrell/wrapcheck/v2 v2.10.0
## explicit; go 1.21
# github.com/tommy-muehle/go-mnd/v2 v2.5.1
## explicit; go 1.12
# github.com/ultraware/funlen v0.2.0
## explicit; go 1.22.0
# github.com/ultraware/whitespace v0.2.0
## explicit; go 1.20
# github.com/uptrace/bun v1.2.16
## explicit; go 1.24.0
github.com/uptrace/bun
@@ -118,6 +418,10 @@ github.com/uptrace/bun/internal
github.com/uptrace/bun/internal/parser
github.com/uptrace/bun/internal/tagparser
github.com/uptrace/bun/schema
# github.com/uudashr/gocognit v1.2.0
## explicit; go 1.19
# github.com/uudashr/iface v1.3.1
## explicit; go 1.22.1
# github.com/vmihailenco/msgpack/v5 v5.4.1
## explicit; go 1.19
github.com/vmihailenco/msgpack/v5
@@ -127,9 +431,37 @@ github.com/vmihailenco/msgpack/v5/msgpcode
github.com/vmihailenco/tagparser/v2
github.com/vmihailenco/tagparser/v2/internal
github.com/vmihailenco/tagparser/v2/internal/parser
# github.com/xen0n/gosmopolitan v1.2.2
## explicit; go 1.19
# github.com/yagipy/maintidx v1.0.0
## explicit; go 1.17
# github.com/yeya24/promlinter v0.3.0
## explicit; go 1.20
# github.com/ykadowak/zerologlint v0.1.5
## explicit; go 1.19
# gitlab.com/bosi/decorder v0.4.2
## explicit; go 1.20
# go-simpler.org/musttag v0.13.0
## explicit; go 1.20
# go-simpler.org/sloglint v0.9.0
## explicit; go 1.22.0
# go.uber.org/atomic v1.7.0
## explicit; go 1.13
# go.uber.org/automaxprocs v1.6.0
## explicit; go 1.20
# go.uber.org/multierr v1.6.0
## explicit; go 1.12
# go.uber.org/zap v1.24.0
## explicit; go 1.19
# golang.org/x/crypto v0.41.0
## explicit; go 1.23.0
golang.org/x/crypto/pbkdf2
# golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac
## explicit; go 1.18
# golang.org/x/mod v0.26.0
## explicit; go 1.23.0
# golang.org/x/sync v0.16.0
## explicit; go 1.23.0
# golang.org/x/sys v0.38.0
## explicit; go 1.24.0
golang.org/x/sys/cpu
@@ -156,6 +488,20 @@ golang.org/x/text/transform
golang.org/x/text/unicode/bidi
golang.org/x/text/unicode/norm
golang.org/x/text/width
# golang.org/x/tools v0.35.0
## explicit; go 1.23.0
# google.golang.org/protobuf v1.36.5
## explicit; go 1.21
# gopkg.in/ini.v1 v1.67.0
## explicit
# gopkg.in/yaml.v2 v2.4.0
## explicit; go 1.15
# gopkg.in/yaml.v3 v3.0.1
## explicit
gopkg.in/yaml.v3
# honnef.co/go/tools v0.6.1
## explicit; go 1.23
# mvdan.cc/gofumpt v0.7.0
## explicit; go 1.22
# mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f
## explicit; go 1.21