Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fb09b78c3 | ||
| 5d9770b430 | |||
| f2d500f98d | |||
| 2ec9991324 | |||
| a3e45c206d | |||
| 165623bb1d | |||
| 3c20c3c5d9 | |||
| a54594e49b | |||
| cafe6a461f | |||
| abdb9b4c78 | |||
| e7a15c8e4f | |||
| c36b5ede2b | |||
| 51ab29f8e3 | |||
| f532fc110c | |||
| 92dff99725 | |||
| 283b568adb | |||
| 122743ee43 | |||
| 91b6046b9b | |||
| 6f55505444 | |||
| e0e7b64c69 | |||
| 4181cb1fbd | |||
| 120ffc6a5a | |||
| b20ad35485 | |||
| f258f8baeb | |||
| 6388daba56 | |||
| f6c3f2b460 | |||
| 156e655571 | |||
| b57e1ba304 | |||
| 19fba62f1b | |||
| b4ff4334cc | |||
| 5d9b00c8f2 | |||
| debf351c48 | |||
| d87d657275 | |||
| 1795eb64d1 | |||
| 355f0f918f | |||
| 5d3c86119e |
@@ -4,10 +4,7 @@
|
||||
"description": "Database Relations Specification Tool for Go",
|
||||
"language": "go"
|
||||
},
|
||||
"agent": {
|
||||
"preferred": "Explore",
|
||||
"description": "Use Explore agent for fast codebase navigation and Go project exploration"
|
||||
},
|
||||
|
||||
"codeStyle": {
|
||||
"useGofmt": true,
|
||||
"lineLength": 100,
|
||||
|
||||
5
.github/workflows/integration-tests.yml
vendored
5
.github/workflows/integration-tests.yml
vendored
@@ -46,6 +46,11 @@ jobs:
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Install PostgreSQL client
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y postgresql-client
|
||||
|
||||
- name: Initialize test database
|
||||
env:
|
||||
PGPASSWORD: relspec_test_password
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -47,3 +47,4 @@ dist/
|
||||
build/
|
||||
bin/
|
||||
tests/integration/failed_statements_example.txt
|
||||
test_output.log
|
||||
|
||||
72
README.md
72
README.md
@@ -85,6 +85,29 @@ RelSpec includes a powerful schema validation and linting tool:
|
||||
## Use of AI
|
||||
[Rules and use of AI](./AI_USE.md)
|
||||
|
||||
## User Interface
|
||||
|
||||
RelSpec provides an interactive terminal-based user interface for managing and editing database schemas. The UI allows you to:
|
||||
|
||||
- **Browse Databases** - Navigate through your database structure with an intuitive menu system
|
||||
- **Edit Schemas** - Create, modify, and organize database schemas
|
||||
- **Manage Tables** - Add, update, or delete tables with full control over structure
|
||||
- **Configure Columns** - Define column properties, data types, constraints, and relationships
|
||||
- **Interactive Editing** - Real-time validation and feedback as you make changes
|
||||
|
||||
The interface supports multiple input formats, making it easy to load, edit, and save your database definitions in various formats.
|
||||
|
||||
<p align="center" width="100%">
|
||||
<img src="./assets/image/screenshots/main_screen.jpg">
|
||||
</p>
|
||||
<p align="center" width="100%">
|
||||
<img src="./assets/image/screenshots/table_view.jpg">
|
||||
</p>
|
||||
<p align="center" width="100%">
|
||||
<img src="./assets/image/screenshots/edit_column.jpg">
|
||||
</p>
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
@@ -95,6 +118,55 @@ go install -v git.warky.dev/wdevs/relspecgo/cmd/relspec@latest
|
||||
|
||||
## Usage
|
||||
|
||||
### Interactive Schema Editor
|
||||
|
||||
```bash
|
||||
# Launch interactive editor with a DBML schema
|
||||
relspec edit --from dbml --from-path schema.dbml --to dbml --to-path schema.dbml
|
||||
|
||||
# Edit PostgreSQL database in place
|
||||
relspec edit --from pgsql --from-conn "postgres://user:pass@localhost/mydb" \
|
||||
--to pgsql --to-conn "postgres://user:pass@localhost/mydb"
|
||||
|
||||
# Edit JSON schema and save as GORM models
|
||||
relspec edit --from json --from-path db.json --to gorm --to-path models/
|
||||
```
|
||||
|
||||
The `edit` command launches an interactive terminal user interface where you can:
|
||||
- Browse and navigate your database structure
|
||||
- Create, modify, and delete schemas, tables, and columns
|
||||
- Configure column properties, constraints, and relationships
|
||||
- Save changes to various formats
|
||||
- Import and merge schemas from other databases
|
||||
|
||||
### Schema Merging
|
||||
|
||||
```bash
|
||||
# Merge two JSON schemas (additive merge - adds missing items only)
|
||||
relspec merge --target json --target-path base.json \
|
||||
--source json --source-path additions.json \
|
||||
--output json --output-path merged.json
|
||||
|
||||
# Merge PostgreSQL database into JSON, skipping specific tables
|
||||
relspec merge --target json --target-path current.json \
|
||||
--source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
|
||||
--output json --output-path updated.json \
|
||||
--skip-tables "audit_log,temp_tables"
|
||||
|
||||
# Cross-format merge (DBML + YAML → JSON)
|
||||
relspec merge --target dbml --target-path base.dbml \
|
||||
--source yaml --source-path additions.yaml \
|
||||
--output json --output-path result.json \
|
||||
--skip-relations --skip-views
|
||||
```
|
||||
|
||||
The `merge` command combines two database schemas additively:
|
||||
- Adds missing schemas, tables, columns, and other objects
|
||||
- Never modifies or deletes existing items (safe operation)
|
||||
- Supports selective merging with skip options (domains, relations, enums, views, sequences, specific tables)
|
||||
- Works across any combination of supported formats
|
||||
- Perfect for integrating multiple schema definitions or applying patches
|
||||
|
||||
### Schema Conversion
|
||||
|
||||
```bash
|
||||
|
||||
11
TODO.md
11
TODO.md
@@ -22,6 +22,17 @@
|
||||
- [✔️] GraphQL schema generation
|
||||
|
||||
|
||||
## UI
|
||||
- [✔️] Basic UI (I went with tview)
|
||||
- [✔️] Save / Load Database
|
||||
- [✔️] Schemas / Domains / Tables
|
||||
- [ ] Add Relations
|
||||
- [ ] Add Indexes
|
||||
- [ ] Add Views
|
||||
- [ ] Add Sequences
|
||||
- [ ] Add Scripts
|
||||
- [ ] Domain / Table Assignment
|
||||
|
||||
## Documentation
|
||||
- [ ] API documentation (godoc)
|
||||
- [ ] Usage examples for each format combination
|
||||
|
||||
BIN
assets/image/screenshots/edit_column.jpg
Normal file
BIN
assets/image/screenshots/edit_column.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 42 KiB |
BIN
assets/image/screenshots/main_screen.jpg
Normal file
BIN
assets/image/screenshots/main_screen.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
BIN
assets/image/screenshots/table_view.jpg
Normal file
BIN
assets/image/screenshots/table_view.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 67 KiB |
@@ -38,13 +38,14 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
convertSourceType string
|
||||
convertSourcePath string
|
||||
convertSourceConn string
|
||||
convertTargetType string
|
||||
convertTargetPath string
|
||||
convertPackageName string
|
||||
convertSchemaFilter string
|
||||
convertSourceType string
|
||||
convertSourcePath string
|
||||
convertSourceConn string
|
||||
convertTargetType string
|
||||
convertTargetPath string
|
||||
convertPackageName string
|
||||
convertSchemaFilter string
|
||||
convertFlattenSchema bool
|
||||
)
|
||||
|
||||
var convertCmd = &cobra.Command{
|
||||
@@ -148,6 +149,7 @@ func init() {
|
||||
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
|
||||
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
||||
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
||||
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||
|
||||
err := convertCmd.MarkFlagRequired("from")
|
||||
if err != nil {
|
||||
@@ -202,7 +204,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
||||
}
|
||||
|
||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter); err != nil {
|
||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema); err != nil {
|
||||
return fmt.Errorf("failed to write target: %w", err)
|
||||
}
|
||||
|
||||
@@ -301,12 +303,13 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string) error {
|
||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool) error {
|
||||
var writer writers.Writer
|
||||
|
||||
writerOpts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
PackageName: packageName,
|
||||
OutputPath: outputPath,
|
||||
PackageName: packageName,
|
||||
FlattenSchema: flattenSchema,
|
||||
}
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
|
||||
334
cmd/relspec/edit.go
Normal file
334
cmd/relspec/edit.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/dctx"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/drizzle"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/gorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/graphql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/ui"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
|
||||
wdbml "git.warky.dev/wdevs/relspecgo/pkg/writers/dbml"
|
||||
wdctx "git.warky.dev/wdevs/relspecgo/pkg/writers/dctx"
|
||||
wdrawdb "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
|
||||
wdrizzle "git.warky.dev/wdevs/relspecgo/pkg/writers/drizzle"
|
||||
wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
wgraphql "git.warky.dev/wdevs/relspecgo/pkg/writers/graphql"
|
||||
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
|
||||
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
)
|
||||
|
||||
var (
|
||||
editSourceType string
|
||||
editSourcePath string
|
||||
editSourceConn string
|
||||
editTargetType string
|
||||
editTargetPath string
|
||||
editSchemaFilter string
|
||||
)
|
||||
|
||||
var editCmd = &cobra.Command{
|
||||
Use: "edit",
|
||||
Short: "Edit database schema interactively with TUI",
|
||||
Long: `Edit database schemas from various formats using an interactive terminal UI.
|
||||
|
||||
Allows you to:
|
||||
- List and navigate schemas and tables
|
||||
- Create, edit, and delete schemas
|
||||
- Create, edit, and delete tables
|
||||
- Add, edit, and delete columns
|
||||
- Set table and column properties
|
||||
- Add constraints, indexes, and relationships
|
||||
|
||||
Supports reading from and writing to all supported formats:
|
||||
Input formats:
|
||||
- dbml: DBML schema files
|
||||
- dctx: DCTX schema files
|
||||
- drawdb: DrawDB JSON files
|
||||
- graphql: GraphQL schema files (.graphql, SDL)
|
||||
- json: JSON database schema
|
||||
- yaml: YAML database schema
|
||||
- gorm: GORM model files (Go, file or directory)
|
||||
- bun: Bun model files (Go, file or directory)
|
||||
- drizzle: Drizzle ORM schema files (TypeScript, file or directory)
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL database (live connection)
|
||||
|
||||
Output formats:
|
||||
- dbml: DBML schema files
|
||||
- dctx: DCTX schema files
|
||||
- drawdb: DrawDB JSON files
|
||||
- graphql: GraphQL schema files (.graphql, SDL)
|
||||
- json: JSON database schema
|
||||
- yaml: YAML database schema
|
||||
- gorm: GORM model files (Go)
|
||||
- bun: Bun model files (Go)
|
||||
- drizzle: Drizzle ORM schema files (TypeScript)
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL SQL schema
|
||||
|
||||
PostgreSQL Connection String Examples:
|
||||
postgres://username:password@localhost:5432/database_name
|
||||
postgres://username:password@localhost/database_name
|
||||
postgresql://user:pass@host:5432/dbname?sslmode=disable
|
||||
postgresql://user:pass@host/dbname?sslmode=require
|
||||
host=localhost port=5432 user=username password=pass dbname=mydb sslmode=disable
|
||||
|
||||
Examples:
|
||||
# Edit a DBML schema file
|
||||
relspec edit --from dbml --from-path schema.dbml --to dbml --to-path schema.dbml
|
||||
|
||||
# Edit a PostgreSQL database
|
||||
relspec edit --from pgsql --from-conn "postgres://user:pass@localhost/mydb" \
|
||||
--to pgsql --to-conn "postgres://user:pass@localhost/mydb"
|
||||
|
||||
# Edit JSON schema and output to GORM
|
||||
relspec edit --from json --from-path db.json --to gorm --to-path models/
|
||||
|
||||
# Edit GORM models in place
|
||||
relspec edit --from gorm --from-path ./models --to gorm --to-path ./models`,
|
||||
RunE: runEdit,
|
||||
}
|
||||
|
||||
func init() {
|
||||
editCmd.Flags().StringVar(&editSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
editCmd.Flags().StringVar(&editSourcePath, "from-path", "", "Source file path (for file-based formats)")
|
||||
editCmd.Flags().StringVar(&editSourceConn, "from-conn", "", "Source connection string (for database formats)")
|
||||
editCmd.Flags().StringVar(&editTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
editCmd.Flags().StringVar(&editTargetPath, "to-path", "", "Target file path (for file-based formats)")
|
||||
editCmd.Flags().StringVar(&editSchemaFilter, "schema", "", "Filter to a specific schema by name")
|
||||
|
||||
// Flags are now optional - if not provided, UI will prompt for load/save options
|
||||
}
|
||||
|
||||
func runEdit(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, "\n=== RelSpec Schema Editor ===\n")
|
||||
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
|
||||
|
||||
var db *models.Database
|
||||
var loadConfig *ui.LoadConfig
|
||||
var saveConfig *ui.SaveConfig
|
||||
var err error
|
||||
|
||||
// Check if source parameters are provided
|
||||
if editSourceType != "" {
|
||||
// Read source database
|
||||
fmt.Fprintf(os.Stderr, "[1/3] Reading source schema...\n")
|
||||
fmt.Fprintf(os.Stderr, " Format: %s\n", editSourceType)
|
||||
if editSourcePath != "" {
|
||||
fmt.Fprintf(os.Stderr, " Path: %s\n", editSourcePath)
|
||||
}
|
||||
if editSourceConn != "" {
|
||||
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(editSourceConn))
|
||||
}
|
||||
|
||||
db, err = readDatabaseForEdit(editSourceType, editSourcePath, editSourceConn, "Source")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read source: %w", err)
|
||||
}
|
||||
|
||||
// Apply schema filter if specified
|
||||
if editSchemaFilter != "" {
|
||||
db = filterDatabaseBySchema(db, editSchemaFilter)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, " ✓ Successfully read database '%s'\n", db.Name)
|
||||
fmt.Fprintf(os.Stderr, " Found: %d schema(s)\n", len(db.Schemas))
|
||||
|
||||
totalTables := 0
|
||||
for _, schema := range db.Schemas {
|
||||
totalTables += len(schema.Tables)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " Found: %d table(s)\n\n", totalTables)
|
||||
|
||||
// Store load config
|
||||
loadConfig = &ui.LoadConfig{
|
||||
SourceType: editSourceType,
|
||||
FilePath: editSourcePath,
|
||||
ConnString: editSourceConn,
|
||||
}
|
||||
} else {
|
||||
// No source parameters provided, UI will show load screen
|
||||
fmt.Fprintf(os.Stderr, "[1/2] No source specified, editor will prompt for database\n\n")
|
||||
}
|
||||
|
||||
// Store save config if target parameters are provided
|
||||
if editTargetType != "" {
|
||||
saveConfig = &ui.SaveConfig{
|
||||
TargetType: editTargetType,
|
||||
FilePath: editTargetPath,
|
||||
}
|
||||
}
|
||||
|
||||
// Launch interactive TUI
|
||||
if editSourceType != "" {
|
||||
fmt.Fprintf(os.Stderr, "[2/3] Launching interactive editor...\n")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "[2/2] Launching interactive editor...\n")
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " Use arrow keys and shortcuts to navigate\n")
|
||||
fmt.Fprintf(os.Stderr, " Press ? for help\n\n")
|
||||
|
||||
editor := ui.NewSchemaEditorWithConfigs(db, loadConfig, saveConfig)
|
||||
if err := editor.Run(); err != nil {
|
||||
return fmt.Errorf("editor failed: %w", err)
|
||||
}
|
||||
|
||||
// Only write to output if target parameters were provided and database was loaded from command line
|
||||
if editTargetType != "" && editSourceType != "" && db != nil {
|
||||
fmt.Fprintf(os.Stderr, "[3/3] Writing changes to output...\n")
|
||||
fmt.Fprintf(os.Stderr, " Format: %s\n", editTargetType)
|
||||
if editTargetPath != "" {
|
||||
fmt.Fprintf(os.Stderr, " Path: %s\n", editTargetPath)
|
||||
}
|
||||
|
||||
// Get the potentially modified database from the editor
|
||||
err = writeDatabaseForEdit(editTargetType, editTargetPath, "", editor.GetDatabase(), "Target")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, " ✓ Successfully written database\n")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n=== Edit complete ===\n")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Database, error) {
|
||||
var reader readers.Reader
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
case "dbml":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
|
||||
}
|
||||
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "dctx":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||
}
|
||||
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "drawdb":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||
}
|
||||
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "graphql":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||
}
|
||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "json":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
|
||||
}
|
||||
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "yaml":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
|
||||
}
|
||||
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "gorm":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
|
||||
}
|
||||
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "bun":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
|
||||
}
|
||||
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "drizzle":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||
}
|
||||
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "prisma":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||
}
|
||||
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "typeorm":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||
}
|
||||
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "pgsql":
|
||||
if connString == "" {
|
||||
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
||||
}
|
||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
default:
|
||||
return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
||||
}
|
||||
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", label, err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Database, label string) error {
|
||||
var writer writers.Writer
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
case "dbml":
|
||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "dctx":
|
||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "drawdb":
|
||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "graphql":
|
||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "json":
|
||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "yaml":
|
||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "gorm":
|
||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "bun":
|
||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "drizzle":
|
||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "prisma":
|
||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "typeorm":
|
||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "pgsql":
|
||||
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
default:
|
||||
return fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
||||
}
|
||||
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", label, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
453
cmd/relspec/merge.go
Normal file
453
cmd/relspec/merge.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/merge"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/dctx"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/drizzle"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/gorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/graphql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
|
||||
wdbml "git.warky.dev/wdevs/relspecgo/pkg/writers/dbml"
|
||||
wdctx "git.warky.dev/wdevs/relspecgo/pkg/writers/dctx"
|
||||
wdrawdb "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
|
||||
wdrizzle "git.warky.dev/wdevs/relspecgo/pkg/writers/drizzle"
|
||||
wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
wgraphql "git.warky.dev/wdevs/relspecgo/pkg/writers/graphql"
|
||||
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
|
||||
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
)
|
||||
|
||||
var (
|
||||
mergeTargetType string
|
||||
mergeTargetPath string
|
||||
mergeTargetConn string
|
||||
mergeSourceType string
|
||||
mergeSourcePath string
|
||||
mergeSourceConn string
|
||||
mergeOutputType string
|
||||
mergeOutputPath string
|
||||
mergeOutputConn string
|
||||
mergeSkipDomains bool
|
||||
mergeSkipRelations bool
|
||||
mergeSkipEnums bool
|
||||
mergeSkipViews bool
|
||||
mergeSkipSequences bool
|
||||
mergeSkipTables string // Comma-separated table names to skip
|
||||
mergeVerbose bool
|
||||
mergeReportPath string // Path to write merge report
|
||||
mergeFlattenSchema bool
|
||||
)
|
||||
|
||||
var mergeCmd = &cobra.Command{
|
||||
Use: "merge",
|
||||
Short: "Merge database schemas (additive only - adds missing items)",
|
||||
Long: `Merge one database schema into another. Performs additive merging only:
|
||||
adds missing schemas, tables, columns, and other objects without modifying
|
||||
or deleting existing items.
|
||||
|
||||
The target database is loaded first, then the source database is merged into it.
|
||||
The result can be saved to a new format or updated in place.
|
||||
|
||||
Examples:
|
||||
# Merge two JSON schemas
|
||||
relspec merge --target json --target-path base.json \
|
||||
--source json --source-path additional.json \
|
||||
--output json --output-path merged.json
|
||||
|
||||
# Merge from PostgreSQL into JSON
|
||||
relspec merge --target json --target-path mydb.json \
|
||||
--source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
|
||||
--output json --output-path combined.json
|
||||
|
||||
# Merge and execute on PostgreSQL database with report
|
||||
relspec merge --target json --target-path base.json \
|
||||
--source json --source-path additional.json \
|
||||
--output pgsql --output-conn "postgres://user:pass@localhost/target_db" \
|
||||
--merge-report merge-report.json
|
||||
|
||||
# Merge DBML and YAML, skip relations
|
||||
relspec merge --target dbml --target-path schema.dbml \
|
||||
--source yaml --source-path tables.yaml \
|
||||
--output dbml --output-path merged.dbml \
|
||||
--skip-relations
|
||||
|
||||
# Merge and save back to target format
|
||||
relspec merge --target json --target-path base.json \
|
||||
--source json --source-path patch.json \
|
||||
--output json --output-path base.json`,
|
||||
RunE: runMerge,
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Target database flags
|
||||
mergeCmd.Flags().StringVar(&mergeTargetType, "target", "", "Target format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
||||
mergeCmd.Flags().StringVar(&mergeTargetPath, "target-path", "", "Target file path (required for file-based formats)")
|
||||
mergeCmd.Flags().StringVar(&mergeTargetConn, "target-conn", "", "Target connection string (required for pgsql)")
|
||||
|
||||
// Source database flags
|
||||
mergeCmd.Flags().StringVar(&mergeSourceType, "source", "", "Source format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
||||
mergeCmd.Flags().StringVar(&mergeSourcePath, "source-path", "", "Source file path (required for file-based formats)")
|
||||
mergeCmd.Flags().StringVar(&mergeSourceConn, "source-conn", "", "Source connection string (required for pgsql)")
|
||||
|
||||
// Output flags
|
||||
mergeCmd.Flags().StringVar(&mergeOutputType, "output", "", "Output format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
|
||||
mergeCmd.Flags().StringVar(&mergeOutputPath, "output-path", "", "Output file path (required for file-based formats)")
|
||||
mergeCmd.Flags().StringVar(&mergeOutputConn, "output-conn", "", "Output connection string (for pgsql)")
|
||||
|
||||
// Merge options
|
||||
mergeCmd.Flags().BoolVar(&mergeSkipDomains, "skip-domains", false, "Skip domains during merge")
|
||||
mergeCmd.Flags().BoolVar(&mergeSkipRelations, "skip-relations", false, "Skip relations during merge")
|
||||
mergeCmd.Flags().BoolVar(&mergeSkipEnums, "skip-enums", false, "Skip enums during merge")
|
||||
mergeCmd.Flags().BoolVar(&mergeSkipViews, "skip-views", false, "Skip views during merge")
|
||||
mergeCmd.Flags().BoolVar(&mergeSkipSequences, "skip-sequences", false, "Skip sequences during merge")
|
||||
mergeCmd.Flags().StringVar(&mergeSkipTables, "skip-tables", "", "Comma-separated list of table names to skip during merge")
|
||||
mergeCmd.Flags().BoolVar(&mergeVerbose, "verbose", false, "Show verbose output")
|
||||
mergeCmd.Flags().StringVar(&mergeReportPath, "merge-report", "", "Path to write merge report (JSON format)")
|
||||
mergeCmd.Flags().BoolVar(&mergeFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||
}
|
||||
|
||||
func runMerge(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, "\n=== RelSpec Merge ===\n")
|
||||
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
|
||||
|
||||
// Validate required flags
|
||||
if mergeTargetType == "" {
|
||||
return fmt.Errorf("--target format is required")
|
||||
}
|
||||
if mergeSourceType == "" {
|
||||
return fmt.Errorf("--source format is required")
|
||||
}
|
||||
if mergeOutputType == "" {
|
||||
return fmt.Errorf("--output format is required")
|
||||
}
|
||||
|
||||
// Validate and expand file paths
|
||||
if mergeTargetType != "pgsql" {
|
||||
if mergeTargetPath == "" {
|
||||
return fmt.Errorf("--target-path is required for %s format", mergeTargetType)
|
||||
}
|
||||
mergeTargetPath = expandPath(mergeTargetPath)
|
||||
} else if mergeTargetConn == "" {
|
||||
|
||||
return fmt.Errorf("--target-conn is required for pgsql format")
|
||||
|
||||
}
|
||||
|
||||
if mergeSourceType != "pgsql" {
|
||||
if mergeSourcePath == "" {
|
||||
return fmt.Errorf("--source-path is required for %s format", mergeSourceType)
|
||||
}
|
||||
mergeSourcePath = expandPath(mergeSourcePath)
|
||||
} else if mergeSourceConn == "" {
|
||||
return fmt.Errorf("--source-conn is required for pgsql format")
|
||||
}
|
||||
|
||||
if mergeOutputType != "pgsql" {
|
||||
if mergeOutputPath == "" {
|
||||
return fmt.Errorf("--output-path is required for %s format", mergeOutputType)
|
||||
}
|
||||
mergeOutputPath = expandPath(mergeOutputPath)
|
||||
}
|
||||
|
||||
// Step 1: Read target database
|
||||
fmt.Fprintf(os.Stderr, "[1/3] Reading target database...\n")
|
||||
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeTargetType)
|
||||
if mergeTargetPath != "" {
|
||||
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeTargetPath)
|
||||
}
|
||||
if mergeTargetConn != "" {
|
||||
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(mergeTargetConn))
|
||||
}
|
||||
|
||||
targetDB, err := readDatabaseForMerge(mergeTargetType, mergeTargetPath, mergeTargetConn, "Target")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read target database: %w", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " ✓ Successfully read target database '%s'\n", targetDB.Name)
|
||||
printDatabaseStats(targetDB)
|
||||
|
||||
// Step 2: Read source database
|
||||
fmt.Fprintf(os.Stderr, "\n[2/3] Reading source database...\n")
|
||||
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeSourceType)
|
||||
if mergeSourcePath != "" {
|
||||
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeSourcePath)
|
||||
}
|
||||
if mergeSourceConn != "" {
|
||||
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(mergeSourceConn))
|
||||
}
|
||||
|
||||
sourceDB, err := readDatabaseForMerge(mergeSourceType, mergeSourcePath, mergeSourceConn, "Source")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read source database: %w", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " ✓ Successfully read source database '%s'\n", sourceDB.Name)
|
||||
printDatabaseStats(sourceDB)
|
||||
|
||||
// Step 3: Merge databases
|
||||
fmt.Fprintf(os.Stderr, "\n[3/3] Merging databases...\n")
|
||||
|
||||
opts := &merge.MergeOptions{
|
||||
SkipDomains: mergeSkipDomains,
|
||||
SkipRelations: mergeSkipRelations,
|
||||
SkipEnums: mergeSkipEnums,
|
||||
SkipViews: mergeSkipViews,
|
||||
SkipSequences: mergeSkipSequences,
|
||||
}
|
||||
|
||||
// Parse skip-tables flag
|
||||
if mergeSkipTables != "" {
|
||||
opts.SkipTableNames = parseSkipTables(mergeSkipTables)
|
||||
if len(opts.SkipTableNames) > 0 {
|
||||
fmt.Fprintf(os.Stderr, " Skipping tables: %s\n", mergeSkipTables)
|
||||
}
|
||||
}
|
||||
|
||||
result := merge.MergeDatabases(targetDB, sourceDB, opts)
|
||||
|
||||
// Update timestamp
|
||||
targetDB.UpdateDate()
|
||||
|
||||
// Print merge summary
|
||||
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
|
||||
|
||||
// Step 4: Write output
|
||||
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
|
||||
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
|
||||
if mergeOutputPath != "" {
|
||||
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeOutputPath)
|
||||
}
|
||||
|
||||
err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, mergeOutputConn, targetDB, "Output", mergeFlattenSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, " ✓ Successfully written merged database\n")
|
||||
fmt.Fprintf(os.Stderr, "\n=== Merge complete ===\n")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.Database, error) {
|
||||
var reader readers.Reader
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
case "dbml":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
|
||||
}
|
||||
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "dctx":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||
}
|
||||
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "drawdb":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||
}
|
||||
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "graphql":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||
}
|
||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "json":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
|
||||
}
|
||||
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "yaml":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
|
||||
}
|
||||
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "gorm":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
|
||||
}
|
||||
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "bun":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
|
||||
}
|
||||
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "drizzle":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||
}
|
||||
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "prisma":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||
}
|
||||
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "typeorm":
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||
}
|
||||
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
case "pgsql":
|
||||
if connString == "" {
|
||||
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
||||
}
|
||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
default:
|
||||
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
||||
}
|
||||
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string, flattenSchema bool) error {
|
||||
var writer writers.Writer
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
case "dbml":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for DBML format", label)
|
||||
}
|
||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "dctx":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||
}
|
||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "drawdb":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||
}
|
||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "graphql":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||
}
|
||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "json":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for JSON format", label)
|
||||
}
|
||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "yaml":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for YAML format", label)
|
||||
}
|
||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "gorm":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for GORM format", label)
|
||||
}
|
||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "bun":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for Bun format", label)
|
||||
}
|
||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "drizzle":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||
}
|
||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "prisma":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||
}
|
||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "typeorm":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||
}
|
||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "pgsql":
|
||||
writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}
|
||||
if connString != "" {
|
||||
writerOpts.Metadata = map[string]interface{}{
|
||||
"connection_string": connString,
|
||||
}
|
||||
// Add report path if merge report is enabled
|
||||
if mergeReportPath != "" {
|
||||
writerOpts.Metadata["report_path"] = mergeReportPath
|
||||
}
|
||||
}
|
||||
writer = wpgsql.NewWriter(writerOpts)
|
||||
default:
|
||||
return fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
||||
}
|
||||
|
||||
return writer.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// expandPath expands a leading tilde in path to the current user's home
// directory. Only the exact forms "~" and "~/..." are expanded; a path like
// "~otheruser/docs" refers to a different user's home and is returned
// unchanged (the previous check on path[0] alone wrongly expanded it).
// If the home directory cannot be determined, the path is returned as-is.
func expandPath(path string) string {
	if path == "~" || strings.HasPrefix(path, "~/") {
		home, err := os.UserHomeDir()
		if err == nil {
			return filepath.Join(home, path[1:])
		}
	}
	return path
}
|
||||
|
||||
func printDatabaseStats(db *models.Database) {
|
||||
totalTables := 0
|
||||
totalColumns := 0
|
||||
totalConstraints := 0
|
||||
totalIndexes := 0
|
||||
|
||||
for _, schema := range db.Schemas {
|
||||
totalTables += len(schema.Tables)
|
||||
for _, table := range schema.Tables {
|
||||
totalColumns += len(table.Columns)
|
||||
totalConstraints += len(table.Constraints)
|
||||
totalIndexes += len(table.Indexes)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, " Schemas: %d, Tables: %d, Columns: %d, Constraints: %d, Indexes: %d\n",
|
||||
len(db.Schemas), totalTables, totalColumns, totalConstraints, totalIndexes)
|
||||
}
|
||||
|
||||
// parseSkipTables converts a comma-separated list of table names into a
// lookup set. Entries are trimmed of surrounding whitespace and lowercased
// so later membership checks are case-insensitive; blank entries are
// dropped. An empty input yields an empty (non-nil) set.
func parseSkipTables(skipTablesStr string) map[string]bool {
	set := make(map[string]bool)
	if skipTablesStr == "" {
		return set
	}

	for _, entry := range strings.Split(skipTablesStr, ",") {
		if name := strings.ToLower(strings.TrimSpace(entry)); name != "" {
			set[name] = true
		}
	}

	return set
}
|
||||
@@ -21,4 +21,7 @@ func init() {
|
||||
rootCmd.AddCommand(inspectCmd)
|
||||
rootCmd.AddCommand(scriptsCmd)
|
||||
rootCmd.AddCommand(templCmd)
|
||||
rootCmd.AddCommand(editCmd)
|
||||
rootCmd.AddCommand(mergeCmd)
|
||||
rootCmd.AddCommand(splitCmd)
|
||||
}
|
||||
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
scriptsDir string
|
||||
scriptsConn string
|
||||
scriptsSchemaName string
|
||||
scriptsDBName string
|
||||
scriptsDir string
|
||||
scriptsConn string
|
||||
scriptsSchemaName string
|
||||
scriptsDBName string
|
||||
scriptsIgnoreErrors bool
|
||||
)
|
||||
|
||||
var scriptsCmd = &cobra.Command{
|
||||
@@ -39,8 +40,8 @@ Example filenames (hyphen format):
|
||||
1-002-create-posts.sql # Priority 1, Sequence 2
|
||||
10-10-create-newid.pgsql # Priority 10, Sequence 10
|
||||
|
||||
Both formats can be mixed in the same directory.
|
||||
Scripts are executed in order: Priority (ascending), then Sequence (ascending).`,
|
||||
Both formats can be mixed in the same directory and subdirectories.
|
||||
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).`,
|
||||
}
|
||||
|
||||
var scriptsListCmd = &cobra.Command{
|
||||
@@ -48,8 +49,8 @@ var scriptsListCmd = &cobra.Command{
|
||||
Short: "List SQL scripts from a directory",
|
||||
Long: `List SQL scripts from a directory and show their execution order.
|
||||
|
||||
The scripts are read from the specified directory and displayed in the order
|
||||
they would be executed (Priority ascending, then Sequence ascending).
|
||||
The scripts are read recursively from the specified directory and displayed in the order
|
||||
they would be executed: Priority (ascending), then Sequence (ascending), then Name (alphabetical).
|
||||
|
||||
Example:
|
||||
relspec scripts list --dir ./migrations`,
|
||||
@@ -61,10 +62,10 @@ var scriptsExecuteCmd = &cobra.Command{
|
||||
Short: "Execute SQL scripts against a database",
|
||||
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
|
||||
|
||||
Scripts are executed in order: Priority (ascending), then Sequence (ascending).
|
||||
Execution stops immediately on the first error.
|
||||
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).
|
||||
By default, execution stops immediately on the first error. Use --ignore-errors to continue execution.
|
||||
|
||||
The directory is scanned recursively for files matching the patterns:
|
||||
The directory is scanned recursively for all subdirectories and files matching the patterns:
|
||||
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
|
||||
{priority}-{sequence}-{name}.sql or .pgsql (hyphen format)
|
||||
|
||||
@@ -75,7 +76,7 @@ PostgreSQL Connection String Examples:
|
||||
postgresql://user:pass@host/dbname?sslmode=require
|
||||
|
||||
Examples:
|
||||
# Execute migration scripts
|
||||
# Execute migration scripts from a directory (including subdirectories)
|
||||
relspec scripts execute --dir ./migrations \
|
||||
--conn "postgres://user:pass@localhost:5432/mydb"
|
||||
|
||||
@@ -86,7 +87,12 @@ Examples:
|
||||
|
||||
# Execute with SSL disabled
|
||||
relspec scripts execute --dir ./sql \
|
||||
--conn "postgres://user:pass@localhost/db?sslmode=disable"`,
|
||||
--conn "postgres://user:pass@localhost/db?sslmode=disable"
|
||||
|
||||
# Continue executing even if errors occur
|
||||
relspec scripts execute --dir ./migrations \
|
||||
--conn "postgres://localhost/mydb" \
|
||||
--ignore-errors`,
|
||||
RunE: runScriptsExecute,
|
||||
}
|
||||
|
||||
@@ -105,6 +111,7 @@ func init() {
|
||||
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
|
||||
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
|
||||
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
|
||||
scriptsExecuteCmd.Flags().BoolVar(&scriptsIgnoreErrors, "ignore-errors", false, "Continue executing scripts even if errors occur")
|
||||
|
||||
err = scriptsExecuteCmd.MarkFlagRequired("dir")
|
||||
if err != nil {
|
||||
@@ -149,7 +156,7 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort scripts by Priority then Sequence
|
||||
// Sort scripts by Priority, Sequence, then Name
|
||||
sortedScripts := make([]*struct {
|
||||
name string
|
||||
priority int
|
||||
@@ -186,7 +193,10 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
|
||||
if sortedScripts[i].priority != sortedScripts[j].priority {
|
||||
return sortedScripts[i].priority < sortedScripts[j].priority
|
||||
}
|
||||
return sortedScripts[i].sequence < sortedScripts[j].sequence
|
||||
if sortedScripts[i].sequence != sortedScripts[j].sequence {
|
||||
return sortedScripts[i].sequence < sortedScripts[j].sequence
|
||||
}
|
||||
return sortedScripts[i].name < sortedScripts[j].name
|
||||
})
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Found %d script(s) in execution order:\n\n", len(sortedScripts))
|
||||
@@ -242,22 +252,44 @@ func runScriptsExecute(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " ✓ Found %d script(s)\n\n", len(schema.Scripts))
|
||||
|
||||
// Step 2: Execute scripts
|
||||
fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence)...\n\n")
|
||||
fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence → Name)...\n\n")
|
||||
|
||||
writer := sqlexec.NewWriter(&writers.WriterOptions{
|
||||
Metadata: map[string]any{
|
||||
"connection_string": scriptsConn,
|
||||
"ignore_errors": scriptsIgnoreErrors,
|
||||
},
|
||||
})
|
||||
|
||||
if err := writer.WriteSchema(schema); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return fmt.Errorf("execution failed: %w", err)
|
||||
return fmt.Errorf("script execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Get execution results from writer metadata
|
||||
totalCount := len(schema.Scripts)
|
||||
successCount := totalCount
|
||||
failedCount := 0
|
||||
|
||||
opts := writer.Options()
|
||||
if total, exists := opts.Metadata["execution_total"].(int); exists {
|
||||
totalCount = total
|
||||
}
|
||||
if success, exists := opts.Metadata["execution_success"].(int); exists {
|
||||
successCount = success
|
||||
}
|
||||
if failed, exists := opts.Metadata["execution_failed"].(int); exists {
|
||||
failedCount = failed
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
|
||||
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
|
||||
fmt.Fprintf(os.Stderr, "Successfully executed %d script(s)\n\n", len(schema.Scripts))
|
||||
fmt.Fprintf(os.Stderr, "Total scripts: %d\n", totalCount)
|
||||
fmt.Fprintf(os.Stderr, "Successful: %d\n", successCount)
|
||||
if failedCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Failed: %d\n", failedCount)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
319
cmd/relspec/split.go
Normal file
319
cmd/relspec/split.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// Command-line flag values for the split command; bound to flags in init.
var (
	splitSourceType    string // --from: source format identifier (dbml, json, pgsql, ...)
	splitSourcePath    string // --from-path: source file path for file-based formats
	splitSourceConn    string // --from-conn: source connection string for database formats
	splitTargetType    string // --to: target format identifier
	splitTargetPath    string // --to-path: output file or directory
	splitSchemas       string // --schemas: comma-separated schema names to include
	splitTables        string // --tables: comma-separated table names to include (case-insensitive)
	splitPackageName   string // --package: package name for code-generation outputs (gorm/bun)
	splitDatabaseName  string // --database-name: override for the output database name
	splitExcludeSchema string // --exclude-schema: comma-separated schema names to exclude
	splitExcludeTables string // --exclude-tables: comma-separated table names to exclude (case-insensitive)
)
|
||||
|
||||
var splitCmd = &cobra.Command{
|
||||
Use: "split",
|
||||
Short: "Split database schemas to extract selected tables into a separate database",
|
||||
Long: `Extract selected schemas and tables from a database and write them to a separate output.
|
||||
|
||||
The split command allows you to:
|
||||
- Select specific schemas to include in the output
|
||||
- Select specific tables within schemas
|
||||
- Exclude specific schemas or tables if preferred
|
||||
- Export the selected subset to any supported format
|
||||
|
||||
Input formats:
|
||||
- dbml: DBML schema files
|
||||
- dctx: DCTX schema files
|
||||
- drawdb: DrawDB JSON files
|
||||
- graphql: GraphQL schema files (.graphql, SDL)
|
||||
- json: JSON database schema
|
||||
- yaml: YAML database schema
|
||||
- gorm: GORM model files (Go, file or directory)
|
||||
- bun: Bun model files (Go, file or directory)
|
||||
- drizzle: Drizzle ORM schema files (TypeScript, file or directory)
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL database (live connection)
|
||||
|
||||
Output formats:
|
||||
- dbml: DBML schema files
|
||||
- dctx: DCTX schema files
|
||||
- drawdb: DrawDB JSON files
|
||||
- graphql: GraphQL schema files (.graphql, SDL)
|
||||
- json: JSON database schema
|
||||
- yaml: YAML database schema
|
||||
- gorm: GORM model files (Go)
|
||||
- bun: Bun model files (Go)
|
||||
- drizzle: Drizzle ORM schema files (TypeScript)
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL SQL schema
|
||||
|
||||
Examples:
|
||||
# Split specific schemas from DBML
|
||||
relspec split --from dbml --from-path schema.dbml \
|
||||
--schemas public,auth \
|
||||
--to json --to-path subset.json
|
||||
|
||||
# Extract specific tables from PostgreSQL
|
||||
relspec split --from pgsql \
|
||||
--from-conn "postgres://user:pass@localhost:5432/mydb" \
|
||||
--schemas public \
|
||||
--tables users,orders,products \
|
||||
--to dbml --to-path subset.dbml
|
||||
|
||||
# Exclude specific tables
|
||||
relspec split --from json --from-path schema.json \
|
||||
--exclude-tables "audit_log,system_config,temp_data" \
|
||||
--to json --to-path public_schema.json
|
||||
|
||||
# Split and convert to GORM
|
||||
relspec split --from json --from-path schema.json \
|
||||
--tables "users,posts,comments" \
|
||||
--to gorm --to-path models/ --package models \
|
||||
--database-name MyAppDB
|
||||
|
||||
# Exclude specific schema and tables
|
||||
relspec split --from pgsql \
|
||||
--from-conn "postgres://user:pass@localhost/db" \
|
||||
--exclude-schema pg_catalog,information_schema \
|
||||
--exclude-tables "temp_users,debug_logs" \
|
||||
--to json --to-path public_schema.json`,
|
||||
RunE: runSplit,
|
||||
}
|
||||
|
||||
func init() {
|
||||
splitCmd.Flags().StringVar(&splitSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
splitCmd.Flags().StringVar(&splitSourcePath, "from-path", "", "Source file path (for file-based formats)")
|
||||
splitCmd.Flags().StringVar(&splitSourceConn, "from-conn", "", "Source connection string (for database formats)")
|
||||
|
||||
splitCmd.Flags().StringVar(&splitTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
splitCmd.Flags().StringVar(&splitTargetPath, "to-path", "", "Target output path (file or directory)")
|
||||
splitCmd.Flags().StringVar(&splitPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
||||
splitCmd.Flags().StringVar(&splitDatabaseName, "database-name", "", "Override database name in output")
|
||||
|
||||
splitCmd.Flags().StringVar(&splitSchemas, "schemas", "", "Comma-separated list of schema names to include")
|
||||
splitCmd.Flags().StringVar(&splitTables, "tables", "", "Comma-separated list of table names to include (case-insensitive)")
|
||||
splitCmd.Flags().StringVar(&splitExcludeSchema, "exclude-schema", "", "Comma-separated list of schema names to exclude")
|
||||
splitCmd.Flags().StringVar(&splitExcludeTables, "exclude-tables", "", "Comma-separated list of table names to exclude (case-insensitive)")
|
||||
|
||||
err := splitCmd.MarkFlagRequired("from")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking from flag as required: %v\n", err)
|
||||
}
|
||||
err = splitCmd.MarkFlagRequired("to")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking to flag as required: %v\n", err)
|
||||
}
|
||||
err = splitCmd.MarkFlagRequired("to-path")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking to-path flag as required: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// runSplit implements the "split" command. It reads a source database in the
// format selected by --from, filters it down to the requested schemas and
// tables via filterDatabase, optionally renames the result (--database-name),
// and writes the subset in the --to format. Progress is reported on stderr;
// the first failing step aborts with a wrapped error.
func runSplit(cmd *cobra.Command, args []string) error {
	fmt.Fprintf(os.Stderr, "\n=== RelSpec Schema Split ===\n")
	fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())

	// Read source database
	fmt.Fprintf(os.Stderr, "[1/3] Reading source schema...\n")
	fmt.Fprintf(os.Stderr, " Format: %s\n", splitSourceType)
	if splitSourcePath != "" {
		fmt.Fprintf(os.Stderr, " Path: %s\n", splitSourcePath)
	}
	if splitSourceConn != "" {
		// Connection strings may embed credentials; mask before logging.
		fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(splitSourceConn))
	}

	// Reuses the convert command's reader dispatch for all supported formats.
	db, err := readDatabaseForConvert(splitSourceType, splitSourcePath, splitSourceConn)
	if err != nil {
		return fmt.Errorf("failed to read source: %w", err)
	}

	fmt.Fprintf(os.Stderr, " ✓ Successfully read database '%s'\n", db.Name)
	fmt.Fprintf(os.Stderr, " Found: %d schema(s)\n", len(db.Schemas))

	// Pre-filter table total, for the progress report only.
	totalTables := 0
	for _, schema := range db.Schemas {
		totalTables += len(schema.Tables)
	}
	fmt.Fprintf(os.Stderr, " Found: %d table(s)\n\n", totalTables)

	// Filter the database
	fmt.Fprintf(os.Stderr, "[2/3] Filtering schemas and tables...\n")
	filteredDB, err := filterDatabase(db)
	if err != nil {
		return fmt.Errorf("failed to filter database: %w", err)
	}

	// Optional rename of the output database.
	if splitDatabaseName != "" {
		filteredDB.Name = splitDatabaseName
	}

	// Post-filter table total, for the progress report only.
	filteredTables := 0
	for _, schema := range filteredDB.Schemas {
		filteredTables += len(schema.Tables)
	}
	fmt.Fprintf(os.Stderr, " ✓ Filtered to: %d schema(s), %d table(s)\n\n", len(filteredDB.Schemas), filteredTables)

	// Write to target format
	fmt.Fprintf(os.Stderr, "[3/3] Writing to target format...\n")
	fmt.Fprintf(os.Stderr, " Format: %s\n", splitTargetType)
	fmt.Fprintf(os.Stderr, " Output: %s\n", splitTargetPath)
	if splitPackageName != "" {
		fmt.Fprintf(os.Stderr, " Package: %s\n", splitPackageName)
	}

	err = writeDatabase(
		filteredDB,
		splitTargetType,
		splitTargetPath,
		splitPackageName,
		"", // no schema filter for split
		false, // no flatten-schema for split
	)
	if err != nil {
		return fmt.Errorf("failed to write output: %w", err)
	}

	fmt.Fprintf(os.Stderr, " ✓ Successfully written to '%s'\n\n", splitTargetPath)
	fmt.Fprintf(os.Stderr, "=== Split Completed Successfully ===\n")
	fmt.Fprintf(os.Stderr, "Completed at: %s\n\n", getCurrentTimestamp())

	return nil
}
|
||||
|
||||
// filterDatabase filters the database based on provided criteria
|
||||
func filterDatabase(db *models.Database) (*models.Database, error) {
|
||||
filteredDB := &models.Database{
|
||||
Name: db.Name,
|
||||
Description: db.Description,
|
||||
Comment: db.Comment,
|
||||
DatabaseType: db.DatabaseType,
|
||||
DatabaseVersion: db.DatabaseVersion,
|
||||
SourceFormat: db.SourceFormat,
|
||||
UpdatedAt: db.UpdatedAt,
|
||||
GUID: db.GUID,
|
||||
Schemas: []*models.Schema{},
|
||||
Domains: db.Domains, // Keep domains for now
|
||||
}
|
||||
|
||||
// Parse filter flags
|
||||
includeSchemas := parseCommaSeparated(splitSchemas)
|
||||
includeTables := parseCommaSeparated(splitTables)
|
||||
excludeSchemas := parseCommaSeparated(splitExcludeSchema)
|
||||
excludeTables := parseCommaSeparated(splitExcludeTables)
|
||||
|
||||
// Convert table names to lowercase for case-insensitive matching
|
||||
includeTablesLower := make(map[string]bool)
|
||||
for _, t := range includeTables {
|
||||
includeTablesLower[strings.ToLower(t)] = true
|
||||
}
|
||||
|
||||
excludeTablesLower := make(map[string]bool)
|
||||
for _, t := range excludeTables {
|
||||
excludeTablesLower[strings.ToLower(t)] = true
|
||||
}
|
||||
|
||||
// Iterate through schemas
|
||||
for _, schema := range db.Schemas {
|
||||
// Check if schema should be excluded
|
||||
if contains(excludeSchemas, schema.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if schema should be included
|
||||
if len(includeSchemas) > 0 && !contains(includeSchemas, schema.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a copy of the schema with filtered tables
|
||||
filteredSchema := &models.Schema{
|
||||
Name: schema.Name,
|
||||
Description: schema.Description,
|
||||
Owner: schema.Owner,
|
||||
Permissions: schema.Permissions,
|
||||
Comment: schema.Comment,
|
||||
Metadata: schema.Metadata,
|
||||
Scripts: schema.Scripts,
|
||||
Sequence: schema.Sequence,
|
||||
Relations: schema.Relations,
|
||||
Enums: schema.Enums,
|
||||
UpdatedAt: schema.UpdatedAt,
|
||||
GUID: schema.GUID,
|
||||
Tables: []*models.Table{},
|
||||
Views: schema.Views,
|
||||
Sequences: schema.Sequences,
|
||||
}
|
||||
|
||||
// Filter tables within the schema
|
||||
for _, table := range schema.Tables {
|
||||
tableLower := strings.ToLower(table.Name)
|
||||
|
||||
// Check if table should be excluded
|
||||
if excludeTablesLower[tableLower] {
|
||||
continue
|
||||
}
|
||||
|
||||
// If specific tables are requested, only include those
|
||||
if len(includeTablesLower) > 0 {
|
||||
if !includeTablesLower[tableLower] {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filteredSchema.Tables = append(filteredSchema.Tables, table)
|
||||
}
|
||||
|
||||
// Only add schema if it has tables (unless no table filter was specified)
|
||||
if len(filteredSchema.Tables) > 0 || (len(includeTablesLower) == 0 && len(excludeTablesLower) == 0) {
|
||||
filteredDB.Schemas = append(filteredDB.Schemas, filteredSchema)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredDB.Schemas) == 0 {
|
||||
return nil, fmt.Errorf("no schemas matched the filter criteria")
|
||||
}
|
||||
|
||||
return filteredDB, nil
|
||||
}
|
||||
|
||||
// parseCommaSeparated parses a comma-separated string into a slice, trimming whitespace
|
||||
// parseCommaSeparated splits s on commas, trims whitespace from each piece,
// and drops empty entries. The result is always a non-nil slice; an empty
// input yields a zero-length slice.
func parseCommaSeparated(s string) []string {
	result := []string{}
	if s == "" {
		return result
	}

	for _, piece := range strings.Split(s, ",") {
		if v := strings.TrimSpace(piece); v != "" {
			result = append(result, v)
		}
	}
	return result
}
|
||||
|
||||
// contains checks if a string is in a slice
|
||||
// contains reports whether item appears in slice (exact, case-sensitive
// string comparison).
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|
||||
149
docs/DOMAINS_DRAWDB.md
Normal file
149
docs/DOMAINS_DRAWDB.md
Normal file
@@ -0,0 +1,149 @@
|
||||
# Domains and DrawDB Areas Integration
|
||||
|
||||
## Overview
|
||||
|
||||
Domains provide a way to organize tables from potentially multiple schemas into logical business groupings. When working with DrawDB format, domains are automatically imported/exported as **Subject Areas** - a native DrawDB feature for visually grouping tables.
|
||||
|
||||
## How It Works
|
||||
|
||||
### Writing Domains to DrawDB (Export)
|
||||
|
||||
When you export a database with domains to DrawDB format:
|
||||
|
||||
1. **Schema Areas** are created automatically for each schema (existing behavior)
|
||||
2. **Domain Areas** are created for each domain, calculated based on the positions of the tables they contain
|
||||
3. The domain area bounds are automatically calculated to encompass all its tables with a small padding
|
||||
|
||||
```go
|
||||
// Example: Creating a domain and exporting to DrawDB
|
||||
db := models.InitDatabase("mydb")
|
||||
|
||||
// Create an "authentication" domain
|
||||
authDomain := models.InitDomain("authentication")
|
||||
authDomain.Tables = append(authDomain.Tables,
|
||||
models.InitDomainTable("users", "public"),
|
||||
models.InitDomainTable("roles", "public"),
|
||||
models.InitDomainTable("permissions", "public"),
|
||||
)
|
||||
db.Domains = append(db.Domains, authDomain)
|
||||
|
||||
// Create a "financial" domain spanning multiple schemas
|
||||
finDomain := models.InitDomain("financial")
|
||||
finDomain.Tables = append(finDomain.Tables,
|
||||
models.InitDomainTable("accounts", "public"),
|
||||
models.InitDomainTable("transactions", "public"),
|
||||
models.InitDomainTable("ledger", "finance"), // Different schema!
|
||||
)
|
||||
db.Domains = append(db.Domains, finDomain)
|
||||
|
||||
// Write to DrawDB - domains become subject areas
|
||||
writer := drawdb.NewWriter(&writers.WriterOptions{
|
||||
OutputPath: "schema.json",
|
||||
})
|
||||
writer.WriteDatabase(db)
|
||||
```
|
||||
|
||||
The resulting DrawDB JSON will have Subject Areas for both:
|
||||
- "authentication" area containing the auth tables
|
||||
- "financial" area containing the financial tables from both schemas
|
||||
|
||||
### Reading Domains from DrawDB (Import)
|
||||
|
||||
When you import a DrawDB file with Subject Areas:
|
||||
|
||||
1. **Subject Areas** are automatically converted to **Domains**
|
||||
2. Tables are assigned to a domain if they fall within the area's visual bounds
|
||||
3. Table references include both the table name and schema name
|
||||
|
||||
```go
|
||||
// Example: Reading DrawDB with areas
|
||||
reader := drawdb.NewReader(&readers.ReaderOptions{
|
||||
FilePath: "schema.json",
|
||||
})
|
||||
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Access domains
|
||||
for _, domain := range db.Domains {
|
||||
fmt.Printf("Domain: %s\n", domain.Name)
|
||||
for _, domainTable := range domain.Tables {
|
||||
fmt.Printf(" - %s.%s\n", domainTable.SchemaName, domainTable.TableName)
|
||||
|
||||
// Access the actual table reference if loaded
|
||||
if domainTable.RefTable != nil {
|
||||
fmt.Printf(" Description: %s\n", domainTable.RefTable.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Domain Structure
|
||||
|
||||
```go
|
||||
type Domain struct {
|
||||
Name string // Domain name (e.g., "authentication", "user_data")
|
||||
Description string // Optional human-readable description
|
||||
Tables []*DomainTable // Tables belonging to this domain
|
||||
Comment string // Optional comment
|
||||
Metadata map[string]any // Extensible metadata
|
||||
Sequence uint // Ordering hint
|
||||
}
|
||||
|
||||
type DomainTable struct {
|
||||
TableName string // Table name
|
||||
SchemaName string // Schema containing the table
|
||||
Sequence uint // Ordering hint
|
||||
RefTable *Table // Pointer to actual table (in-memory only, not serialized)
|
||||
}
|
||||
```
|
||||
|
||||
## Multi-Schema Domains
|
||||
|
||||
One of the key features of domains is that they can span multiple schemas:
|
||||
|
||||
```
|
||||
Domain: "user_data"
|
||||
├── public.users
|
||||
├── public.profiles
|
||||
├── public.user_preferences
|
||||
├── auth.user_sessions
|
||||
└── auth.mfa_devices
|
||||
```
|
||||
|
||||
This allows you to organize related tables even when they're stored in different schemas.
|
||||
|
||||
## Visual Organization in DrawDB
|
||||
|
||||
When viewing the exported DrawDB file in DrawDB Editor:
|
||||
|
||||
1. **Schema areas** appear in one color (original behavior)
|
||||
2. **Domain areas** appear in a different color
|
||||
3. Domain area bounds are calculated to fit all contained tables
|
||||
4. Areas can overlap - a table can visually belong to multiple areas
|
||||
|
||||
## Integration with Other Formats
|
||||
|
||||
Currently, domain/area integration is implemented for DrawDB format.
|
||||
|
||||
To implement similar functionality for other formats:
|
||||
|
||||
1. Identify if the format has a native grouping/area feature
|
||||
2. Add conversion logic in the reader to map format areas → Domain model
|
||||
3. Add conversion logic in the writer to map Domain model → format areas
|
||||
|
||||
Example formats that could support domains:
|
||||
- **DBML**: Could use DBML's `TableGroup` feature
|
||||
- **DrawDB**: ✅ Already implemented (Subject Areas)
|
||||
- **GraphQL**: Could use schema directives
|
||||
- **Custom formats**: Implement as needed
|
||||
|
||||
## Tips and Best Practices
|
||||
|
||||
1. **Keep domains focused**: Each domain should represent a distinct business area
|
||||
2. **Document purposes**: Use Description and Comment fields to explain each domain
|
||||
3. **Use meaningful names**: Domain names should clearly reflect their purpose
|
||||
4. **Maintain schema consistency**: Keep related tables together in the same schema when possible
|
||||
5. **Use metadata**: Store tool-specific information in the Metadata field
|
||||
@@ -44,12 +44,13 @@ The `--mode` flag controls how the template is executed:
|
||||
|------|-------------|--------|-------------|
|
||||
| `database` | Execute once for entire database | Single file | Documentation, reports, overview files |
|
||||
| `schema` | Execute once per schema | One file per schema | Schema-specific documentation |
|
||||
| `domain` | Execute once per domain | One file per domain | Domain-based documentation, domain exports |
|
||||
| `script` | Execute once per script | One file per script | Script processing |
|
||||
| `table` | Execute once per table | One file per table | Model generation, table docs |
|
||||
|
||||
### Filename Patterns
|
||||
|
||||
For multi-file modes (`schema`, `script`, `table`), use `--filename-pattern` to control output filenames:
|
||||
For multi-file modes (`schema`, `domain`, `script`, `table`), use `--filename-pattern` to control output filenames:
|
||||
|
||||
```bash
|
||||
# Default pattern
|
||||
@@ -296,6 +297,13 @@ The data available in templates depends on the execution mode:
|
||||
.Metadata // map[string]interface{} - User metadata
|
||||
```
|
||||
|
||||
### Domain Mode
|
||||
```go
|
||||
.Domain // *models.Domain - Current domain
|
||||
.ParentDatabase // *models.Database - Parent database context
|
||||
.Metadata // map[string]interface{} - User metadata
|
||||
```
|
||||
|
||||
### Table Mode
|
||||
```go
|
||||
.Table // *models.Table - Current table
|
||||
@@ -317,6 +325,7 @@ The data available in templates depends on the execution mode:
|
||||
**Database:**
|
||||
- `.Name` - Database name
|
||||
- `.Schemas` - List of schemas
|
||||
- `.Domains` - List of domains (business domain groupings)
|
||||
- `.Description`, `.Comment` - Documentation
|
||||
|
||||
**Schema:**
|
||||
@@ -325,6 +334,17 @@ The data available in templates depends on the execution mode:
|
||||
- `.Views`, `.Sequences`, `.Scripts` - Other objects
|
||||
- `.Enums` - Enum types
|
||||
|
||||
**Domain:**
|
||||
- `.Name` - Domain name
|
||||
- `.Tables` - List of DomainTable references
|
||||
- `.Description`, `.Comment` - Documentation
|
||||
- `.Metadata` - Custom metadata map
|
||||
|
||||
**DomainTable:**
|
||||
- `.TableName` - Name of the table
|
||||
- `.SchemaName` - Schema containing the table
|
||||
- `.RefTable` - Pointer to actual Table object (if loaded)
|
||||
|
||||
**Table:**
|
||||
- `.Name` - Table name
|
||||
- `.Schema` - Schema name
|
||||
|
||||
7
go.mod
7
go.mod
@@ -3,8 +3,10 @@ module git.warky.dev/wdevs/relspecgo
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/gdamore/tcell/v2 v2.8.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/rivo/tview v0.42.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/uptrace/bun v1.2.16
|
||||
@@ -14,13 +16,17 @@ require (
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
@@ -28,4 +34,5 @@ require (
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.34.0 // indirect
|
||||
)
|
||||
|
||||
79
go.sum
79
go.sum
@@ -3,6 +3,11 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.8.1 h1:KPNxyqclpWpWQlPLx6Xui1pMk8S+7+R37h3g07997NU=
|
||||
github.com/gdamore/tcell/v2 v2.8.1/go.mod h1:bj8ori1BG3OYMjmb3IklZVWfZUJ1UBQt9JXrOCOhGWw=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
@@ -21,11 +26,21 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/rivo/tview v0.42.0 h1:b/ftp+RxtDsHSaynXTbJb+/n/BxDEi+W3UfF5jILK6c=
|
||||
github.com/rivo/tview v0.42.0/go.mod h1:cSfIYfhpSGCjp3r/ECJb+GKS7cGJnqV8vfjQPwoXyfY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
@@ -48,15 +63,79 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
||||
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
||||
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
714
pkg/commontypes/commontypes_test.go
Normal file
714
pkg/commontypes/commontypes_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package commontypes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractBaseType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
want string
|
||||
}{
|
||||
{"varchar with length", "varchar(100)", "varchar"},
|
||||
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
|
||||
{"numeric with precision", "numeric(10,2)", "numeric"},
|
||||
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
|
||||
{"decimal with precision", "decimal(15,3)", "decimal"},
|
||||
{"char with length", "char(50)", "char"},
|
||||
{"simple integer", "integer", "integer"},
|
||||
{"simple text", "text", "text"},
|
||||
{"bigint", "bigint", "bigint"},
|
||||
{"With spaces", " varchar(100) ", "varchar"},
|
||||
{"No parentheses", "boolean", "boolean"},
|
||||
{"Empty string", "", ""},
|
||||
{"Mixed case", "VarChar(100)", "varchar"},
|
||||
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ExtractBaseType(tt.sqlType)
|
||||
if got != tt.want {
|
||||
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeType(t *testing.T) {
|
||||
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
|
||||
testCases := []string{
|
||||
"varchar(100)",
|
||||
"numeric(10,2)",
|
||||
"integer",
|
||||
"text",
|
||||
" VARCHAR(255) ",
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
extracted := ExtractBaseType(tc)
|
||||
normalized := NormalizeType(tc)
|
||||
if extracted != normalized {
|
||||
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
|
||||
tc, extracted, tc, normalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSQLToGo verifies the SQL→Go type mapping, including pointer wrapping
// driven by the nullable flag. Slice types ([]byte, []string) and interface{}
// are never pointer-wrapped, per the expectations below.
//
// NOTE(review): the expected values pair nullable=true with the plain value
// type (e.g. "int32") and nullable=false with the pointer type ("*int32"),
// which is the inverse of the usual "nullable column → pointer field"
// convention. This may mean the second parameter actually carries "NOT NULL"
// semantics despite being named nullable here — confirm against SQLToGo's
// signature before relying on the test-case names.
func TestSQLToGo(t *testing.T) {
	tests := []struct {
		name     string
		sqlType  string
		nullable bool
		want     string
	}{
		// Integer types (nullable)
		{"integer nullable", "integer", true, "int32"},
		{"bigint nullable", "bigint", true, "int64"},
		{"smallint nullable", "smallint", true, "int16"},
		{"serial nullable", "serial", true, "int32"},

		// Integer types (not nullable)
		{"integer not nullable", "integer", false, "*int32"},
		{"bigint not nullable", "bigint", false, "*int64"},
		{"smallint not nullable", "smallint", false, "*int16"},

		// String types (nullable)
		{"text nullable", "text", true, "string"},
		{"varchar nullable", "varchar", true, "string"},
		{"varchar with length nullable", "varchar(100)", true, "string"},

		// String types (not nullable)
		{"text not nullable", "text", false, "*string"},
		{"varchar not nullable", "varchar", false, "*string"},

		// Boolean
		{"boolean nullable", "boolean", true, "bool"},
		{"boolean not nullable", "boolean", false, "*bool"},

		// Float types
		{"real nullable", "real", true, "float32"},
		{"double precision nullable", "double precision", true, "float64"},
		{"real not nullable", "real", false, "*float32"},
		{"double precision not nullable", "double precision", false, "*float64"},

		// Date/Time types
		{"timestamp nullable", "timestamp", true, "time.Time"},
		{"date nullable", "date", true, "time.Time"},
		{"timestamp not nullable", "timestamp", false, "*time.Time"},

		// Binary
		{"bytea nullable", "bytea", true, "[]byte"},
		{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer

		// UUID
		{"uuid nullable", "uuid", true, "string"},
		{"uuid not nullable", "uuid", false, "*string"},

		// JSON
		{"json nullable", "json", true, "string"},
		{"jsonb nullable", "jsonb", true, "string"},

		// Array
		{"array nullable", "array", true, "[]string"},
		{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer

		// Unknown types
		{"unknown type nullable", "unknowntype", true, "interface{}"},
		{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := SQLToGo(tt.sqlType, tt.nullable)
			if got != tt.want {
				t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestSQLToTypeScript verifies the SQL→TypeScript type mapping. A "| null"
// union suffix is appended depending on the nullable flag.
//
// NOTE(review): nullable=true expects the bare type ("number") and
// nullable=false expects the nullable union ("number | null") — inverted
// from the usual convention, matching the pattern in TestSQLToGo. Possibly
// the flag means "NOT NULL"; confirm against SQLToTypeScript's signature.
func TestSQLToTypeScript(t *testing.T) {
	tests := []struct {
		name     string
		sqlType  string
		nullable bool
		want     string
	}{
		// Integer types
		{"integer nullable", "integer", true, "number"},
		{"integer not nullable", "integer", false, "number | null"},
		{"bigint nullable", "bigint", true, "number"},
		{"bigint not nullable", "bigint", false, "number | null"},

		// String types
		{"text nullable", "text", true, "string"},
		{"text not nullable", "text", false, "string | null"},
		{"varchar nullable", "varchar", true, "string"},
		{"varchar(100) nullable", "varchar(100)", true, "string"},

		// Boolean
		{"boolean nullable", "boolean", true, "boolean"},
		{"boolean not nullable", "boolean", false, "boolean | null"},

		// Float types
		{"real nullable", "real", true, "number"},
		{"double precision nullable", "double precision", true, "number"},

		// Date/Time types
		{"timestamp nullable", "timestamp", true, "Date"},
		{"date nullable", "date", true, "Date"},
		{"timestamp not nullable", "timestamp", false, "Date | null"},

		// Binary
		{"bytea nullable", "bytea", true, "Buffer"},
		{"bytea not nullable", "bytea", false, "Buffer | null"},

		// JSON
		{"json nullable", "json", true, "any"},
		{"jsonb nullable", "jsonb", true, "any"},

		// UUID
		{"uuid nullable", "uuid", true, "string"},

		// Unknown types
		{"unknown type nullable", "unknowntype", true, "any"},
		{"unknown type not nullable", "unknowntype", false, "any | null"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := SQLToTypeScript(tt.sqlType, tt.nullable)
			if got != tt.want {
				t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
func TestSQLToPython(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", "integer", "int"},
|
||||
{"bigint", "bigint", "int"},
|
||||
{"smallint", "smallint", "int"},
|
||||
|
||||
// String types
|
||||
{"text", "text", "str"},
|
||||
{"varchar", "varchar", "str"},
|
||||
{"varchar(100)", "varchar(100)", "str"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", "boolean", "bool"},
|
||||
|
||||
// Float types
|
||||
{"real", "real", "float"},
|
||||
{"double precision", "double precision", "float"},
|
||||
{"numeric", "numeric", "Decimal"},
|
||||
{"decimal", "decimal", "Decimal"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp", "timestamp", "datetime"},
|
||||
{"date", "date", "date"},
|
||||
{"time", "time", "time"},
|
||||
|
||||
// Binary
|
||||
{"bytea", "bytea", "bytes"},
|
||||
|
||||
// JSON
|
||||
{"json", "json", "dict"},
|
||||
{"jsonb", "jsonb", "dict"},
|
||||
|
||||
// UUID
|
||||
{"uuid", "uuid", "UUID"},
|
||||
|
||||
// Array
|
||||
{"array", "array", "list"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type", "unknowntype", "Any"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToPython(tt.sqlType)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSQLToCSharp verifies the SQL→C# type mapping. Value types (int, bool,
// DateTime, Guid, ...) get a "?" suffix depending on the nullable flag;
// reference types (string, object) and arrays (byte[]) never do.
//
// NOTE(review): nullable=true expects the non-nullable value type ("int")
// and nullable=false expects the nullable form ("int?") — inverted from the
// usual convention, consistent with TestSQLToGo/TestSQLToTypeScript.
// Possibly the flag means "NOT NULL"; confirm against SQLToCSharp's signature.
func TestSQLToCSharp(t *testing.T) {
	tests := []struct {
		name     string
		sqlType  string
		nullable bool
		want     string
	}{
		// Integer types (nullable)
		{"integer nullable", "integer", true, "int"},
		{"bigint nullable", "bigint", true, "long"},
		{"smallint nullable", "smallint", true, "short"},

		// Integer types (not nullable - value types get ?)
		{"integer not nullable", "integer", false, "int?"},
		{"bigint not nullable", "bigint", false, "long?"},
		{"smallint not nullable", "smallint", false, "short?"},

		// String types (reference types, no ? needed)
		{"text nullable", "text", true, "string"},
		{"text not nullable", "text", false, "string"},
		{"varchar nullable", "varchar", true, "string"},
		{"varchar(100) nullable", "varchar(100)", true, "string"},

		// Boolean
		{"boolean nullable", "boolean", true, "bool"},
		{"boolean not nullable", "boolean", false, "bool?"},

		// Float types
		{"real nullable", "real", true, "float"},
		{"double precision nullable", "double precision", true, "double"},
		{"decimal nullable", "decimal", true, "decimal"},
		{"real not nullable", "real", false, "float?"},
		{"double precision not nullable", "double precision", false, "double?"},
		{"decimal not nullable", "decimal", false, "decimal?"},

		// Date/Time types
		{"timestamp nullable", "timestamp", true, "DateTime"},
		{"date nullable", "date", true, "DateTime"},
		{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
		{"timestamp not nullable", "timestamp", false, "DateTime?"},
		{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},

		// Binary (array type, no ?)
		{"bytea nullable", "bytea", true, "byte[]"},
		{"bytea not nullable", "bytea", false, "byte[]"},

		// UUID
		{"uuid nullable", "uuid", true, "Guid"},
		{"uuid not nullable", "uuid", false, "Guid?"},

		// JSON
		{"json nullable", "json", true, "string"},

		// Unknown types (object is reference type)
		{"unknown type nullable", "unknowntype", true, "object"},
		{"unknown type not nullable", "unknowntype", false, "object"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := SQLToCSharp(tt.sqlType, tt.nullable)
			if got != tt.want {
				t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
func TestNeedsTimeImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goType string
|
||||
want bool
|
||||
}{
|
||||
{"time.Time type", "time.Time", true},
|
||||
{"pointer to time.Time", "*time.Time", true},
|
||||
{"int32 type", "int32", false},
|
||||
{"string type", "string", false},
|
||||
{"bool type", "bool", false},
|
||||
{"[]byte type", "[]byte", false},
|
||||
{"interface{}", "interface{}", false},
|
||||
{"empty string", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NeedsTimeImport(tt.goType)
|
||||
if got != tt.want {
|
||||
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoTypeMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int32",
|
||||
"bigint": "int64",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "float64",
|
||||
"bytea": "[]byte",
|
||||
"timestamp": "time.Time",
|
||||
"uuid": "string",
|
||||
"json": "string",
|
||||
}
|
||||
|
||||
for sqlType, expectedGoType := range expectedMappings {
|
||||
if goType, ok := GoTypeMap[sqlType]; !ok {
|
||||
t.Errorf("GoTypeMap missing entry for %q", sqlType)
|
||||
} else if goType != expectedGoType {
|
||||
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(GoTypeMap) == 0 {
|
||||
t.Error("GoTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeScriptTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "number",
|
||||
"bigint": "number",
|
||||
"text": "string",
|
||||
"boolean": "boolean",
|
||||
"double precision": "number",
|
||||
"bytea": "Buffer",
|
||||
"timestamp": "Date",
|
||||
"uuid": "string",
|
||||
"json": "any",
|
||||
}
|
||||
|
||||
for sqlType, expectedTSType := range expectedMappings {
|
||||
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
|
||||
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
|
||||
} else if tsType != expectedTSType {
|
||||
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(TypeScriptTypeMap) == 0 {
|
||||
t.Error("TypeScriptTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPythonTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int",
|
||||
"text": "str",
|
||||
"boolean": "bool",
|
||||
"real": "float",
|
||||
"numeric": "Decimal",
|
||||
"bytea": "bytes",
|
||||
"date": "date",
|
||||
"uuid": "UUID",
|
||||
"json": "dict",
|
||||
}
|
||||
|
||||
for sqlType, expectedPyType := range expectedMappings {
|
||||
if pyType, ok := PythonTypeMap[sqlType]; !ok {
|
||||
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
|
||||
} else if pyType != expectedPyType {
|
||||
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(PythonTypeMap) == 0 {
|
||||
t.Error("PythonTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSharpTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "long",
|
||||
"smallint": "short",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "double",
|
||||
"decimal": "decimal",
|
||||
"bytea": "byte[]",
|
||||
"timestamp": "DateTime",
|
||||
"uuid": "Guid",
|
||||
"json": "string",
|
||||
}
|
||||
|
||||
for sqlType, expectedCSType := range expectedMappings {
|
||||
if csType, ok := CSharpTypeMap[sqlType]; !ok {
|
||||
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
|
||||
} else if csType != expectedCSType {
|
||||
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(CSharpTypeMap) == 0 {
|
||||
t.Error("CSharpTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToJava(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer nullable", "integer", true, "Integer"},
|
||||
{"integer not nullable", "integer", false, "Integer"},
|
||||
{"bigint nullable", "bigint", true, "Long"},
|
||||
{"smallint nullable", "smallint", true, "Short"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "String"},
|
||||
{"varchar nullable", "varchar", true, "String"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "Boolean"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "Float"},
|
||||
{"double precision nullable", "double precision", true, "Double"},
|
||||
{"numeric nullable", "numeric", true, "BigDecimal"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "Timestamp"},
|
||||
{"date nullable", "date", true, "Date"},
|
||||
{"time nullable", "time", true, "Time"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "byte[]"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "UUID"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "String"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "Object"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToJava(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToPhp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int"},
|
||||
{"bigint nullable", "bigint", true, "int"},
|
||||
{"smallint nullable", "smallint", true, "int"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "?int"},
|
||||
{"bigint not nullable", "bigint", false, "?int"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "?string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "?bool"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float"},
|
||||
{"double precision nullable", "double precision", true, "float"},
|
||||
{"real not nullable", "real", false, "?float"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "\\DateTime"},
|
||||
{"date nullable", "date", true, "\\DateTime"},
|
||||
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "string"},
|
||||
{"bytea not nullable", "bytea", false, "?string"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "array"},
|
||||
{"json not nullable", "json", false, "?array"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "mixed"},
|
||||
{"unknown type not nullable", "unknowntype", false, "mixed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToPhp(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToRust(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "i32"},
|
||||
{"bigint nullable", "bigint", true, "i64"},
|
||||
{"smallint nullable", "smallint", true, "i16"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "Option<i32>"},
|
||||
{"bigint not nullable", "bigint", false, "Option<i64>"},
|
||||
{"smallint not nullable", "smallint", false, "Option<i16>"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "String"},
|
||||
{"text not nullable", "text", false, "Option<String>"},
|
||||
{"varchar nullable", "varchar", true, "String"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "Option<bool>"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "f32"},
|
||||
{"double precision nullable", "double precision", true, "f64"},
|
||||
{"real not nullable", "real", false, "Option<f32>"},
|
||||
{"double precision not nullable", "double precision", false, "Option<f64>"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
|
||||
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
|
||||
{"date nullable", "date", true, "NaiveDate"},
|
||||
{"time nullable", "time", true, "NaiveTime"},
|
||||
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "Vec<u8>"},
|
||||
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "serde_json::Value"},
|
||||
{"json not nullable", "json", false, "Option<serde_json::Value>"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "String"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "String"},
|
||||
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToRust(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJavaTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "Integer",
|
||||
"bigint": "Long",
|
||||
"smallint": "Short",
|
||||
"text": "String",
|
||||
"boolean": "Boolean",
|
||||
"double precision": "Double",
|
||||
"numeric": "BigDecimal",
|
||||
"bytea": "byte[]",
|
||||
"timestamp": "Timestamp",
|
||||
"uuid": "UUID",
|
||||
"date": "Date",
|
||||
}
|
||||
|
||||
for sqlType, expectedJavaType := range expectedMappings {
|
||||
if javaType, ok := JavaTypeMap[sqlType]; !ok {
|
||||
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
|
||||
} else if javaType != expectedJavaType {
|
||||
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(JavaTypeMap) == 0 {
|
||||
t.Error("JavaTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPHPTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "float",
|
||||
"bytea": "string",
|
||||
"timestamp": "\\DateTime",
|
||||
"uuid": "string",
|
||||
"json": "array",
|
||||
}
|
||||
|
||||
for sqlType, expectedPHPType := range expectedMappings {
|
||||
if phpType, ok := PHPTypeMap[sqlType]; !ok {
|
||||
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
|
||||
} else if phpType != expectedPHPType {
|
||||
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(PHPTypeMap) == 0 {
|
||||
t.Error("PHPTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRustTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "i32",
|
||||
"bigint": "i64",
|
||||
"smallint": "i16",
|
||||
"text": "String",
|
||||
"boolean": "bool",
|
||||
"double precision": "f64",
|
||||
"real": "f32",
|
||||
"bytea": "Vec<u8>",
|
||||
"timestamp": "NaiveDateTime",
|
||||
"timestamptz": "DateTime<Utc>",
|
||||
"date": "NaiveDate",
|
||||
"json": "serde_json::Value",
|
||||
}
|
||||
|
||||
for sqlType, expectedRustType := range expectedMappings {
|
||||
if rustType, ok := RustTypeMap[sqlType]; !ok {
|
||||
t.Errorf("RustTypeMap missing entry for %q", sqlType)
|
||||
} else if rustType != expectedRustType {
|
||||
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(RustTypeMap) == 0 {
|
||||
t.Error("RustTypeMap is empty")
|
||||
}
|
||||
}
|
||||
558
pkg/diff/diff_test.go
Normal file
558
pkg/diff/diff_test.go
Normal file
@@ -0,0 +1,558 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestCompareDatabases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source *models.Database
|
||||
target *models.Database
|
||||
want func(*DiffResult) bool
|
||||
}{
|
||||
{
|
||||
name: "identical databases",
|
||||
source: &models.Database{
|
||||
Name: "source",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
target: &models.Database{
|
||||
Name: "target",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
want: func(r *DiffResult) bool {
|
||||
return r.Source == "source" && r.Target == "target" &&
|
||||
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different schemas",
|
||||
source: &models.Database{
|
||||
Name: "source",
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
target: &models.Database{
|
||||
Name: "target",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
want: func(r *DiffResult) bool {
|
||||
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CompareDatabases(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("CompareDatabases() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Column
|
||||
target map[string]*models.Column
|
||||
want func(*ColumnDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical columns",
|
||||
source: map[string]*models.Column{},
|
||||
target: map[string]*models.Column{},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing column",
|
||||
source: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
target: map[string]*models.Column{},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra column",
|
||||
source: map[string]*models.Column{},
|
||||
target: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column type",
|
||||
source: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint"},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
|
||||
d.Modified[0].Changes["type"] != nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column nullable",
|
||||
source: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "text", NotNull: true},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "text", NotNull: false},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column length",
|
||||
source: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "varchar", Length: 100},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "varchar", Length: 255},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareColumns(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareColumns() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareColumnDetails(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source *models.Column
|
||||
target *models.Column
|
||||
want int // number of changes
|
||||
}{
|
||||
{
|
||||
name: "identical columns",
|
||||
source: &models.Column{Name: "id", Type: "integer"},
|
||||
target: &models.Column{Name: "id", Type: "integer"},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "type change",
|
||||
source: &models.Column{Name: "id", Type: "integer"},
|
||||
target: &models.Column{Name: "id", Type: "bigint"},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "length change",
|
||||
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
|
||||
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "precision change",
|
||||
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
|
||||
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "scale change",
|
||||
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
|
||||
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "not null change",
|
||||
source: &models.Column{Name: "name", Type: "text", NotNull: true},
|
||||
target: &models.Column{Name: "name", Type: "text", NotNull: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "auto increment change",
|
||||
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
|
||||
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "primary key change",
|
||||
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
|
||||
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple changes",
|
||||
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
|
||||
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
|
||||
want: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareColumnDetails(tt.source, tt.target)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareIndexes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Index
|
||||
target map[string]*models.Index
|
||||
want func(*IndexDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical indexes",
|
||||
source: map[string]*models.Index{},
|
||||
target: map[string]*models.Index{},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing index",
|
||||
source: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||
},
|
||||
target: map[string]*models.Index{},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra index",
|
||||
source: map[string]*models.Index{},
|
||||
target: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||
},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified index uniqueness",
|
||||
source: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
|
||||
},
|
||||
target: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
|
||||
},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareIndexes(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareIndexes() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareConstraints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Constraint
|
||||
target map[string]*models.Constraint
|
||||
want func(*ConstraintDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical constraints",
|
||||
source: map[string]*models.Constraint{},
|
||||
target: map[string]*models.Constraint{},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing constraint",
|
||||
source: map[string]*models.Constraint{
|
||||
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||
},
|
||||
target: map[string]*models.Constraint{},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra constraint",
|
||||
source: map[string]*models.Constraint{},
|
||||
target: map[string]*models.Constraint{
|
||||
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||
},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareConstraints(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareConstraints() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareRelationships(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Relationship
|
||||
target map[string]*models.Relationship
|
||||
want func(*RelationshipDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical relationships",
|
||||
source: map[string]*models.Relationship{},
|
||||
target: map[string]*models.Relationship{},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing relationship",
|
||||
source: map[string]*models.Relationship{
|
||||
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||
},
|
||||
target: map[string]*models.Relationship{},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra relationship",
|
||||
source: map[string]*models.Relationship{},
|
||||
target: map[string]*models.Relationship{
|
||||
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||
},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareRelationships(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareRelationships() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source []*models.Table
|
||||
target []*models.Table
|
||||
want func(*TableDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical tables",
|
||||
source: []*models.Table{},
|
||||
target: []*models.Table{},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing table",
|
||||
source: []*models.Table{
|
||||
{Name: "users", Schema: "public"},
|
||||
},
|
||||
target: []*models.Table{},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra table",
|
||||
source: []*models.Table{},
|
||||
target: []*models.Table{
|
||||
{Name: "users", Schema: "public"},
|
||||
},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified table",
|
||||
source: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
target: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareTables(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareTables() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareSchemas(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source []*models.Schema
|
||||
target []*models.Schema
|
||||
want func(*SchemaDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical schemas",
|
||||
source: []*models.Schema{},
|
||||
target: []*models.Schema{},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing schema",
|
||||
source: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
target: []*models.Schema{},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra schema",
|
||||
source: []*models.Schema{},
|
||||
target: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareSchemas(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareSchemas() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
want bool
|
||||
}{
|
||||
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
|
||||
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
|
||||
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
|
||||
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
|
||||
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
|
||||
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
|
||||
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
|
||||
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
|
||||
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isEmpty(tt.v)
|
||||
if got != tt.want {
|
||||
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
want func(*Summary) bool
|
||||
}{
|
||||
{
|
||||
name: "empty diff",
|
||||
result: &DiffResult{
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
want: func(s *Summary) bool {
|
||||
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "schemas with differences",
|
||||
result: &DiffResult{
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
|
||||
Modified: []*SchemaChange{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(s *Summary) bool {
|
||||
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ComputeSummary(tt.result)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("ComputeSummary() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
440
pkg/diff/formatters_test.go
Normal file
440
pkg/diff/formatters_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestFormatDiff(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source_db",
|
||||
Target: "target_db",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
format OutputFormat
|
||||
wantErr bool
|
||||
}{
|
||||
{"summary format", FormatSummary, false},
|
||||
{"json format", FormatJSON, false},
|
||||
{"html format", FormatHTML, false},
|
||||
{"invalid format", OutputFormat("invalid"), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := FormatDiff(result, tt.format, &buf)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && buf.Len() == 0 {
|
||||
t.Error("FormatDiff() produced empty output")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
wantStr []string // strings that should appear in output
|
||||
}{
|
||||
{
|
||||
name: "no differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"source", "target", "No differences found"},
|
||||
},
|
||||
{
|
||||
name: "with schema differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{{Name: "schema2"}},
|
||||
Modified: []*SchemaChange{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||
},
|
||||
{
|
||||
name: "with table differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Missing: []*models.Table{{Name: "users"}},
|
||||
Extra: []*models.Table{{Name: "posts"}},
|
||||
Modified: []*TableChange{
|
||||
{Name: "comments", Schema: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(tt.result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range tt.wantStr {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatJSON(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatJSON(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatJSON() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if output is valid JSON
|
||||
var decoded DiffResult
|
||||
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
|
||||
t.Errorf("formatJSON() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check basic structure
|
||||
if decoded.Source != "source" {
|
||||
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
|
||||
}
|
||||
if decoded.Target != "target" {
|
||||
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
|
||||
}
|
||||
if len(decoded.Schemas.Missing) != 1 {
|
||||
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatHTML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
wantStr []string // HTML elements/content that should appear
|
||||
}{
|
||||
{
|
||||
name: "basic HTML structure",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"<!DOCTYPE html>",
|
||||
"<title>Database Diff Report</title>",
|
||||
"source",
|
||||
"target",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with schema differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "missing_schema"}},
|
||||
Extra: []*models.Schema{{Name: "extra_schema"}},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"<!DOCTYPE html>",
|
||||
"missing_schema",
|
||||
"extra_schema",
|
||||
"MISSING",
|
||||
"EXTRA",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with table modifications",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: &ColumnDiff{
|
||||
Missing: []*models.Column{{Name: "email", Type: "text"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"public",
|
||||
"users",
|
||||
"email",
|
||||
"text",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := formatHTML(tt.result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatHTML() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range tt.wantStr {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatHTML() output doesn't contain %q", want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatSummaryWithColumns checks that column-level differences inside a
// modified table are rolled up into the summary output with correct counts.
func TestFormatSummaryWithColumns(t *testing.T) {
	// Fixture: one modified table carrying 1 missing, 2 extra,
	// and 1 modified column.
	result := &DiffResult{
		Source: "source",
		Target: "target",
		Schemas: &SchemaDiff{
			Modified: []*SchemaChange{
				{
					Name: "public",
					Tables: &TableDiff{
						Modified: []*TableChange{
							{
								Name:   "users",
								Schema: "public",
								Columns: &ColumnDiff{
									Missing: []*models.Column{{Name: "email"}},
									Extra:   []*models.Column{{Name: "phone"}, {Name: "address"}},
									Modified: []*ColumnChange{
										{Name: "name"},
									},
								},
							},
						},
					},
				},
			},
		},
	}

	var buf bytes.Buffer
	err := formatSummary(result, &buf)

	if err != nil {
		t.Errorf("formatSummary() error = %v", err)
		return
	}

	output := buf.String()
	// Expected counts mirror the fixture above (1 missing, 2 extra, 1 modified).
	wantStrings := []string{
		"Columns:",
		"Missing: 1",
		"Extra: 2",
		"Modified: 1",
	}

	for _, want := range wantStrings {
		if !strings.Contains(output, want) {
			t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
		}
	}
}
|
||||
|
||||
func TestFormatSummaryWithIndexes(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Indexes: &IndexDiff{
|
||||
Missing: []*models.Index{{Name: "idx_email"}},
|
||||
Extra: []*models.Index{{Name: "idx_phone"}},
|
||||
Modified: []*IndexChange{{Name: "idx_name"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Indexes:") {
|
||||
t.Error("formatSummary() output doesn't contain Indexes section")
|
||||
}
|
||||
if !strings.Contains(output, "Missing: 1") {
|
||||
t.Error("formatSummary() output doesn't contain correct missing count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithConstraints(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Constraints: &ConstraintDiff{
|
||||
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
|
||||
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Constraints:") {
|
||||
t.Error("formatSummary() output doesn't contain Constraints section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatJSONIndentation(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "test"}},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatJSON(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatJSON() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that JSON is indented (has newlines and spaces)
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "\n") {
|
||||
t.Error("formatJSON() should produce indented JSON with newlines")
|
||||
}
|
||||
if !strings.Contains(output, " ") {
|
||||
t.Error("formatJSON() should produce indented JSON with spaces")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputFormatConstants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
format OutputFormat
|
||||
want string
|
||||
}{
|
||||
{"summary constant", FormatSummary, "summary"},
|
||||
{"json constant", FormatJSON, "json"},
|
||||
{"html constant", FormatHTML, "html"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if string(tt.format) != tt.want {
|
||||
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
238
pkg/inspector/inspector_test.go
Normal file
238
pkg/inspector/inspector_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewInspector(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
if inspector == nil {
|
||||
t.Fatal("NewInspector() returned nil")
|
||||
}
|
||||
|
||||
if inspector.db != db {
|
||||
t.Error("NewInspector() database not set correctly")
|
||||
}
|
||||
|
||||
if inspector.config != config {
|
||||
t.Error("NewInspector() config not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspect(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report == nil {
|
||||
t.Fatal("Inspect() returned nil report")
|
||||
}
|
||||
|
||||
if report.Database != db.Name {
|
||||
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
|
||||
}
|
||||
|
||||
if report.Summary.TotalRules != len(config.Rules) {
|
||||
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
|
||||
}
|
||||
|
||||
if len(report.Violations) == 0 {
|
||||
t.Error("Inspect() returned no violations, expected some results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithDisabledRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Disable all rules
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 0 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
if report.Summary.RulesSkipped != len(config.Rules) {
|
||||
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithEnforcedRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Enable only one rule and enforce it
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
primaryKeyRule := config.Rules["primary_key_naming"]
|
||||
primaryKeyRule.Enabled = "enforce"
|
||||
primaryKeyRule.Pattern = "^id$"
|
||||
config.Rules["primary_key_naming"] = primaryKeyRule
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 1 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
// All results should be at error level for enforced rules
|
||||
for _, violation := range report.Violations {
|
||||
if violation.Level != "error" {
|
||||
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSummary(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
results := []ValidationResult{
|
||||
{RuleName: "rule1", Passed: true, Level: "error"},
|
||||
{RuleName: "rule2", Passed: false, Level: "error"},
|
||||
{RuleName: "rule3", Passed: false, Level: "warning"},
|
||||
{RuleName: "rule4", Passed: true, Level: "warning"},
|
||||
}
|
||||
|
||||
summary := inspector.generateSummary(results)
|
||||
|
||||
if summary.PassedCount != 2 {
|
||||
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
|
||||
}
|
||||
|
||||
if summary.ErrorCount != 1 {
|
||||
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
|
||||
}
|
||||
|
||||
if summary.WarningCount != 1 {
|
||||
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report *InspectorReport
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "with errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 5,
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "without errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 0,
|
||||
WarningCount: 3,
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.report.HasErrors(); got != tt.want {
|
||||
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
functionName string
|
||||
wantExists bool
|
||||
}{
|
||||
{"primary_key_naming", "primary_key_naming", true},
|
||||
{"primary_key_datatype", "primary_key_datatype", true},
|
||||
{"foreign_key_column_naming", "foreign_key_column_naming", true},
|
||||
{"table_regexpr", "table_regexpr", true},
|
||||
{"column_regexpr", "column_regexpr", true},
|
||||
{"reserved_words", "reserved_words", true},
|
||||
{"have_primary_key", "have_primary_key", true},
|
||||
{"orphaned_foreign_key", "orphaned_foreign_key", true},
|
||||
{"circular_dependency", "circular_dependency", true},
|
||||
{"unknown_function", "unknown_function", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, exists := getValidator(tt.functionName)
|
||||
if exists != tt.wantExists {
|
||||
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateResult(t *testing.T) {
|
||||
result := createResult(
|
||||
"test_rule",
|
||||
true,
|
||||
"Test message",
|
||||
"schema.table.column",
|
||||
map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
},
|
||||
)
|
||||
|
||||
if result.RuleName != "test_rule" {
|
||||
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
|
||||
}
|
||||
|
||||
if !result.Passed {
|
||||
t.Error("createResult() Passed = false, want true")
|
||||
}
|
||||
|
||||
if result.Message != "Test message" {
|
||||
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
|
||||
}
|
||||
|
||||
if result.Location != "schema.table.column" {
|
||||
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
|
||||
}
|
||||
|
||||
if len(result.Context) != 2 {
|
||||
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
|
||||
}
|
||||
}
|
||||
366
pkg/inspector/report_test.go
Normal file
366
pkg/inspector/report_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// createTestReport builds a fixed InspectorReport fixture with one
// error-level and one warning-level violation for formatter tests.
func createTestReport() *InspectorReport {
	return &InspectorReport{
		// NOTE(review): the summary counts below are standalone fixture
		// values; they are not derived from the two Violations entries.
		Summary: ReportSummary{
			TotalRules:   10,
			RulesChecked: 8,
			RulesSkipped: 2,
			ErrorCount:   3,
			WarningCount: 5,
			PassedCount:  12,
		},
		Violations: []ValidationResult{
			{
				RuleName: "primary_key_naming",
				Level:    "error",
				Message:  "Primary key should start with 'id_'",
				Location: "public.users.user_id",
				Passed:   false,
				Context: map[string]interface{}{
					"schema":  "public",
					"table":   "users",
					"column":  "user_id",
					"pattern": "^id_",
				},
			},
			{
				RuleName: "table_name_length",
				Level:    "warning",
				Message:  "Table name too long",
				Location: "public.very_long_table_name_that_exceeds_limits",
				Passed:   false,
				Context: map[string]interface{}{
					"schema":     "public",
					"table":      "very_long_table_name_that_exceeds_limits",
					"length":     44,
					"max_length": 32,
				},
			},
		},
		GeneratedAt:  time.Now(),
		Database:     "testdb",
		SourceFormat: "postgresql",
	}
}
|
||||
|
||||
func TestNewMarkdownFormatter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewMarkdownFormatter() returned nil")
|
||||
}
|
||||
|
||||
// Buffer is not a terminal, so colors should be disabled
|
||||
if formatter.UseColors {
|
||||
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJSONFormatter(t *testing.T) {
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewJSONFormatter() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Check that output contains expected sections
|
||||
if !strings.Contains(output, "# RelSpec Inspector Report") {
|
||||
t.Error("Markdown output missing header")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Database:") {
|
||||
t.Error("Markdown output missing database field")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "testdb") {
|
||||
t.Error("Markdown output missing database name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Summary") {
|
||||
t.Error("Markdown output missing summary section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Rules Checked: 8") {
|
||||
t.Error("Markdown output missing rules checked count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Errors: 3") {
|
||||
t.Error("Markdown output missing error count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Warnings: 5") {
|
||||
t.Error("Markdown output missing warning count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Violations") {
|
||||
t.Error("Markdown output missing violations section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "primary_key_naming") {
|
||||
t.Error("Markdown output missing rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "public.users.user_id") {
|
||||
t.Error("Markdown output missing location")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
|
||||
report := &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 10,
|
||||
RulesSkipped: 0,
|
||||
ErrorCount: 0,
|
||||
WarningCount: 0,
|
||||
PassedCount: 50,
|
||||
},
|
||||
Violations: []ValidationResult{},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "No violations found") {
|
||||
t.Error("Markdown output should indicate no violations")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
var decoded InspectorReport
|
||||
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check key fields
|
||||
if decoded.Database != "testdb" {
|
||||
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
|
||||
}
|
||||
|
||||
if decoded.Summary.ErrorCount != 3 {
|
||||
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
|
||||
}
|
||||
|
||||
if len(decoded.Violations) != 2 {
|
||||
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
header := formatter.formatHeader("Test Header")
|
||||
|
||||
if !strings.Contains(header, "# Test Header") {
|
||||
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatBold(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Bold Text",
|
||||
wantContains: "**Bold Text**",
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Bold Text",
|
||||
wantContains: "Bold Text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.formatBold(tt.text)
|
||||
|
||||
if !strings.Contains(result, tt.wantContains) {
|
||||
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Colorize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
color string
|
||||
wantColor bool
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: false,
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.colorize(tt.text, tt.color)
|
||||
|
||||
hasColor := strings.Contains(result, tt.color)
|
||||
if hasColor != tt.wantColor {
|
||||
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, tt.text) {
|
||||
t.Errorf("colorize() doesn't contain original text %q", tt.text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatContext(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
context := map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "id",
|
||||
"pattern": "^id_",
|
||||
"max_length": 64,
|
||||
}
|
||||
|
||||
result := formatter.formatContext(context)
|
||||
|
||||
// Should not include schema, table, column (they're in location)
|
||||
if strings.Contains(result, "schema") {
|
||||
t.Error("formatContext() should skip schema field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "table=") {
|
||||
t.Error("formatContext() should skip table field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "column=") {
|
||||
t.Error("formatContext() should skip column field")
|
||||
}
|
||||
|
||||
// Should include other fields
|
||||
if !strings.Contains(result, "pattern") {
|
||||
t.Error("formatContext() should include pattern field")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "max_length") {
|
||||
t.Error("formatContext() should include max_length field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
violation := ValidationResult{
|
||||
RuleName: "test_rule",
|
||||
Level: "error",
|
||||
Message: "Test violation message",
|
||||
Location: "public.users.id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"pattern": "^id_",
|
||||
},
|
||||
}
|
||||
|
||||
result := formatter.formatViolation(violation, colorRed)
|
||||
|
||||
if !strings.Contains(result, "test_rule") {
|
||||
t.Error("formatViolation() should include rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Test violation message") {
|
||||
t.Error("formatViolation() should include message")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "public.users.id") {
|
||||
t.Error("formatViolation() should include location")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Location:") {
|
||||
t.Error("formatViolation() should include Location label")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Message:") {
|
||||
t.Error("formatViolation() should include Message label")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportFormatConstants(t *testing.T) {
|
||||
// Test that color constants are defined
|
||||
if colorReset == "" {
|
||||
t.Error("colorReset is not defined")
|
||||
}
|
||||
|
||||
if colorRed == "" {
|
||||
t.Error("colorRed is not defined")
|
||||
}
|
||||
|
||||
if colorYellow == "" {
|
||||
t.Error("colorYellow is not defined")
|
||||
}
|
||||
|
||||
if colorGreen == "" {
|
||||
t.Error("colorGreen is not defined")
|
||||
}
|
||||
|
||||
if colorBold == "" {
|
||||
t.Error("colorBold is not defined")
|
||||
}
|
||||
}
|
||||
249
pkg/inspector/rules_test.go
Normal file
249
pkg/inspector/rules_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultConfig(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("GetDefaultConfig() returned nil")
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("GetDefaultConfig() returned no rules")
|
||||
}
|
||||
|
||||
// Check that all expected rules are present
|
||||
expectedRules := []string{
|
||||
"primary_key_naming",
|
||||
"primary_key_datatype",
|
||||
"primary_key_auto_increment",
|
||||
"foreign_key_column_naming",
|
||||
"foreign_key_constraint_naming",
|
||||
"foreign_key_index",
|
||||
"table_naming_case",
|
||||
"column_naming_case",
|
||||
"table_name_length",
|
||||
"column_name_length",
|
||||
"reserved_keywords",
|
||||
"missing_primary_key",
|
||||
"orphaned_foreign_key",
|
||||
"circular_dependency",
|
||||
}
|
||||
|
||||
for _, ruleName := range expectedRules {
|
||||
if _, exists := config.Rules[ruleName]; !exists {
|
||||
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_NonExistentFile(t *testing.T) {
|
||||
// Try to load a non-existent file
|
||||
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should return default config
|
||||
if config == nil {
|
||||
t.Fatal("LoadConfig() returned nil config for non-existent file")
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("LoadConfig() returned config with no rules")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_ValidFile writes a two-rule YAML config to a temp dir and
// verifies LoadConfig parses the version, rule enable modes, pattern, and
// max-length settings.
func TestLoadConfig_ValidFile(t *testing.T) {
	// Create a temporary config file
	tmpDir := t.TempDir()
	configPath := filepath.Join(tmpDir, "test-config.yaml")

	configContent := `version: "1.0"
rules:
  primary_key_naming:
    enabled: "enforce"
    function: "primary_key_naming"
    pattern: "^pk_"
    message: "Primary keys must start with pk_"
  table_name_length:
    enabled: "warn"
    function: "table_name_length"
    max_length: 50
    message: "Table name too long"
`

	err := os.WriteFile(configPath, []byte(configContent), 0644)
	if err != nil {
		t.Fatalf("Failed to create test config file: %v", err)
	}

	config, err := LoadConfig(configPath)
	if err != nil {
		t.Fatalf("LoadConfig() returned error: %v", err)
	}

	if config.Version != "1.0" {
		t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
	}

	// Only the two rules from the file should be loaded (no default merge).
	if len(config.Rules) != 2 {
		t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
	}

	// Check primary_key_naming rule
	pkRule, exists := config.Rules["primary_key_naming"]
	if !exists {
		t.Fatal("LoadConfig() missing primary_key_naming rule")
	}

	if pkRule.Enabled != "enforce" {
		t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
	}

	if pkRule.Pattern != "^pk_" {
		t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
	}

	// Check table_name_length rule
	lengthRule, exists := config.Rules["table_name_length"]
	if !exists {
		t.Fatal("LoadConfig() missing table_name_length rule")
	}

	if lengthRule.MaxLength != 50 {
		t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
	}
}
|
||||
|
||||
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||
// Create a temporary invalid config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
|
||||
|
||||
invalidContent := `invalid: yaml: content: {[}]`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configPath)
|
||||
if err == nil {
|
||||
t.Error("LoadConfig() with invalid YAML did not return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enabled",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is enabled",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "off is not enabled",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty is not enabled",
|
||||
rule: Rule{Enabled: ""},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnabled(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enforced",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is not enforced",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "off is not enforced",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnforced(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfigRuleSettings spot-checks a handful of rule settings returned
// by GetDefaultConfig: naming pattern, allowed datatypes, max length, and the
// reserved-keyword check flags.
func TestDefaultConfigRuleSettings(t *testing.T) {
	config := GetDefaultConfig()

	// Test specific rule settings
	pkNamingRule := config.Rules["primary_key_naming"]
	if pkNamingRule.Function != "primary_key_naming" {
		t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
	}

	if pkNamingRule.Pattern != "^id_" {
		t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
	}

	// Test datatype rule
	pkDatatypeRule := config.Rules["primary_key_datatype"]
	if len(pkDatatypeRule.AllowedTypes) == 0 {
		t.Error("primary_key_datatype has no allowed types")
	}

	// Test length rule
	tableLengthRule := config.Rules["table_name_length"]
	if tableLengthRule.MaxLength != 64 {
		t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
	}

	// Test reserved keywords rule: both table and column checks default to on.
	reservedRule := config.Rules["reserved_keywords"]
	if !reservedRule.CheckTables {
		t.Error("reserved_keywords.CheckTables should be true")
	}
	if !reservedRule.CheckColumns {
		t.Error("reserved_keywords.CheckColumns should be true")
	}
}
|
||||
837
pkg/inspector/validators_test.go
Normal file
837
pkg/inspector/validators_test.go
Normal file
@@ -0,0 +1,837 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// createTestDatabase builds the shared fixture for the validator tests:
// a single "public" schema with two tables — "users" (pk id, a not-null
// username, and rid_organization with a foreign key to organizations.id plus a
// covering index) and "organizations" (pk id, not-null name).
func createTestDatabase() *models.Database {
	return &models.Database{
		Name: "testdb",
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name: "users",
						Columns: map[string]*models.Column{
							"id": {
								Name:          "id",
								Type:          "bigserial",
								IsPrimaryKey:  true,
								AutoIncrement: true,
							},
							"username": {
								Name:         "username",
								Type:         "varchar(50)",
								NotNull:      true,
								IsPrimaryKey: false,
							},
							// FK column naming convention used by the tests: "rid_" prefix.
							"rid_organization": {
								Name:         "rid_organization",
								Type:         "bigint",
								NotNull:      true,
								IsPrimaryKey: false,
							},
						},
						Constraints: map[string]*models.Constraint{
							"fk_users_organization": {
								Name:              "fk_users_organization",
								Type:              models.ForeignKeyConstraint,
								Columns:           []string{"rid_organization"},
								ReferencedTable:   "organizations",
								ReferencedSchema:  "public",
								ReferencedColumns: []string{"id"},
							},
						},
						Indexes: map[string]*models.Index{
							// Index over the FK column; exercised by the FK-index validator test.
							"idx_rid_organization": {
								Name:    "idx_rid_organization",
								Columns: []string{"rid_organization"},
							},
						},
					},
					{
						Name: "organizations",
						Columns: map[string]*models.Column{
							"id": {
								Name:          "id",
								Type:          "bigserial",
								IsPrimaryKey:  true,
								AutoIncrement: true,
							},
							"name": {
								Name:         "name",
								Type:         "varchar(100)",
								NotNull:      true,
								IsPrimaryKey: false,
							},
						},
					},
				},
			},
		},
	}
}
|
||||
|
||||
// TestValidatePrimaryKeyNaming checks primary-key naming against a regex
// pattern: both fixture PKs are named "id", so "^id$" passes and "^id_" fails.
// Two results are expected, one per table with a primary key.
func TestValidatePrimaryKeyNaming(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name     string
		rule     Rule
		wantLen  int
		wantPass bool
	}{
		{
			name: "matching pattern id",
			rule: Rule{
				Pattern: "^id$",
				Message: "Primary key should be 'id'",
			},
			wantLen:  2,
			wantPass: true,
		},
		{
			name: "non-matching pattern id_",
			rule: Rule{
				Pattern: "^id_",
				Message: "Primary key should start with 'id_'",
			},
			wantLen:  2,
			wantPass: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
			}
			// Only the first result's pass/fail state is asserted.
			if len(results) > 0 && results[0].Passed != tt.wantPass {
				t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
			}
		})
	}
}
|
||||
|
||||
// TestValidatePrimaryKeyDatatype checks the allowed-types rule for primary
// keys: the fixture PKs are "bigserial", so an allow-list containing it passes
// and a uuid-only allow-list fails.
func TestValidatePrimaryKeyDatatype(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name     string
		rule     Rule
		wantLen  int
		wantPass bool
	}{
		{
			name: "allowed type bigserial",
			rule: Rule{
				AllowedTypes: []string{"bigserial", "bigint", "int"},
				Message:      "Primary key should use integer types",
			},
			wantLen:  2,
			wantPass: true,
		},
		{
			name: "disallowed type",
			rule: Rule{
				AllowedTypes: []string{"uuid"},
				Message:      "Primary key should use UUID",
			},
			wantLen:  2,
			wantPass: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
			}
			if len(results) > 0 && results[0].Passed != tt.wantPass {
				t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
			}
		})
	}
}
|
||||
|
||||
// TestValidatePrimaryKeyAutoIncrement checks the auto-increment requirement.
// Both fixture PKs have AutoIncrement set, so requiring it yields no
// violations while disallowing it yields one violation per PK.
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name    string
		rule    Rule
		wantLen int
	}{
		{
			name: "require auto increment",
			rule: Rule{
				RequireAutoIncrement: true,
				Message:              "Primary key should have auto-increment",
			},
			wantLen: 0, // No violations - all PKs have auto-increment
		},
		{
			name: "disallow auto increment",
			rule: Rule{
				RequireAutoIncrement: false,
				Message:              "Primary key should not have auto-increment",
			},
			wantLen: 2, // 2 violations - both PKs have auto-increment
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
			}
		})
	}
}
|
||||
|
||||
// TestValidateForeignKeyColumnNaming checks FK column naming: the fixture has
// one FK column, "rid_organization", so "^rid_" passes and "^fk_" fails.
func TestValidateForeignKeyColumnNaming(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name     string
		rule     Rule
		wantLen  int
		wantPass bool
	}{
		{
			name: "matching pattern rid_",
			rule: Rule{
				Pattern: "^rid_",
				Message: "Foreign key columns should start with 'rid_'",
			},
			wantLen:  1,
			wantPass: true,
		},
		{
			name: "non-matching pattern fk_",
			rule: Rule{
				Pattern: "^fk_",
				Message: "Foreign key columns should start with 'fk_'",
			},
			wantLen:  1,
			wantPass: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
			}
			if len(results) > 0 && results[0].Passed != tt.wantPass {
				t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
			}
		})
	}
}
|
||||
|
||||
// TestValidateForeignKeyConstraintNaming checks FK constraint naming: the
// fixture constraint is "fk_users_organization", so "^fk_" passes and the
// uppercase "^FK_" variant fails (pattern matching is case-sensitive).
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name     string
		rule     Rule
		wantLen  int
		wantPass bool
	}{
		{
			name: "matching pattern fk_",
			rule: Rule{
				Pattern: "^fk_",
				Message: "Foreign key constraints should start with 'fk_'",
			},
			wantLen:  1,
			wantPass: true,
		},
		{
			name: "non-matching pattern FK_",
			rule: Rule{
				Pattern: "^FK_",
				Message: "Foreign key constraints should start with 'FK_'",
			},
			wantLen:  1,
			wantPass: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
			}
			if len(results) > 0 && results[0].Passed != tt.wantPass {
				t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
			}
		})
	}
}
|
||||
|
||||
// TestValidateForeignKeyIndex checks the require-index rule for FK columns.
// The fixture FK column has a covering index, so RequireIndex=true passes;
// RequireIndex=false produces no results at all (check disabled).
func TestValidateForeignKeyIndex(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name     string
		rule     Rule
		wantLen  int
		wantPass bool
	}{
		{
			name: "require index with index present",
			rule: Rule{
				RequireIndex: true,
				Message:      "Foreign key columns should have indexes",
			},
			wantLen:  1,
			wantPass: true,
		},
		{
			name: "no requirement",
			rule: Rule{
				RequireIndex: false,
				Message:      "Foreign key index check disabled",
			},
			wantLen:  0,
			wantPass: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validateForeignKeyIndex(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
			}
			if len(results) > 0 && results[0].Passed != tt.wantPass {
				t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
			}
		})
	}
}
|
||||
|
||||
// TestValidateTableNamingCase checks table name case validation: the fixture
// table names ("users", "organizations") match a lowercase snake_case pattern
// and fail an uppercase pattern. One result per table is expected.
func TestValidateTableNamingCase(t *testing.T) {
	db := createTestDatabase()

	tests := []struct {
		name     string
		rule     Rule
		wantLen  int
		wantPass bool
	}{
		{
			name: "lowercase snake_case pattern",
			rule: Rule{
				Pattern: "^[a-z][a-z0-9_]*$",
				Case:    "lowercase",
				Message: "Table names should be lowercase snake_case",
			},
			wantLen:  2,
			wantPass: true,
		},
		{
			name: "uppercase pattern",
			rule: Rule{
				Pattern: "^[A-Z][A-Z0-9_]*$",
				Case:    "uppercase",
				Message: "Table names should be uppercase",
			},
			wantLen:  2,
			wantPass: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := validateTableNamingCase(db, tt.rule, "test_rule")
			if len(results) != tt.wantLen {
				t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
			}
			if len(results) > 0 && results[0].Passed != tt.wantPass {
				t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
			}
		})
	}
}
|
||||
|
||||
func TestValidateColumnNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Column names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 5, // 5 total columns across both tables
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "camelCase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-zA-Z0-9]*$",
|
||||
Case: "camelCase",
|
||||
Message: "Column names should be camelCase",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // rid_organization has underscore
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // Some columns exceed 5 chars
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReservedKeywords(t *testing.T) {
|
||||
// Create a database with reserved keywords
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "user", // "user" is a reserved keyword
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
"select": { // "select" is a reserved keyword
|
||||
Name: "select",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
checkPasses bool
|
||||
}{
|
||||
{
|
||||
name: "check tables only",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: false,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 1, // "user" table
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check columns only",
|
||||
rule: Rule{
|
||||
CheckTables: false,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 2, // "id", "select" columns (id passes, select fails)
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check both",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 3, // "user" table + "id", "select" columns
|
||||
checkPasses: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateReservedKeywords(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateMissingPrimaryKey checks primary-key presence detection using a
// local fixture with one table that has a PK and one that does not. Result
// order follows the Tables slice order, so results[0]/[1] map to
// with_pk/without_pk respectively.
func TestValidateMissingPrimaryKey(t *testing.T) {
	// Create database with and without primary keys
	db := &models.Database{
		Name: "testdb",
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name: "with_pk",
						Columns: map[string]*models.Column{
							"id": {
								Name:         "id",
								Type:         "bigint",
								IsPrimaryKey: true,
							},
						},
					},
					{
						Name: "without_pk",
						Columns: map[string]*models.Column{
							"name": {
								Name: "name",
								Type: "varchar(50)",
							},
						},
					},
				},
			},
		},
	}

	rule := Rule{
		Message: "Table missing primary key",
	}

	results := validateMissingPrimaryKey(db, rule, "test_rule")

	if len(results) != 2 {
		t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
	}

	// First result should pass (with_pk has PK)
	if results[0].Passed != true {
		t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
	}

	// Second result should fail (without_pk missing PK)
	if results[1].Passed != false {
		t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
	}
}
|
||||
|
||||
// TestValidateOrphanedForeignKey checks that a foreign key pointing at a table
// that does not exist in the schema is reported as a failed result.
func TestValidateOrphanedForeignKey(t *testing.T) {
	// Create database with orphaned FK
	db := &models.Database{
		Name: "testdb",
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name: "users",
						Columns: map[string]*models.Column{
							"id": {
								Name:         "id",
								Type:         "bigint",
								IsPrimaryKey: true,
							},
						},
						Constraints: map[string]*models.Constraint{
							// References a table that is not defined anywhere in the fixture.
							"fk_nonexistent": {
								Name:             "fk_nonexistent",
								Type:             models.ForeignKeyConstraint,
								Columns:          []string{"rid_organization"},
								ReferencedTable:  "nonexistent_table",
								ReferencedSchema: "public",
							},
						},
					},
				},
			},
		},
	}

	rule := Rule{
		Message: "Foreign key references non-existent table",
	}

	results := validateOrphanedForeignKey(db, rule, "test_rule")

	if len(results) != 1 {
		t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
	}

	if results[0].Passed != false {
		t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
	}
}
|
||||
|
||||
// TestValidateCircularDependency checks cycle detection with two tables whose
// foreign keys reference each other (table_a -> table_b -> table_a). At least
// one failing result is expected and no result may pass.
func TestValidateCircularDependency(t *testing.T) {
	// Create database with circular dependency
	db := &models.Database{
		Name: "testdb",
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name: "table_a",
						Columns: map[string]*models.Column{
							"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
						},
						Constraints: map[string]*models.Constraint{
							"fk_to_b": {
								Name:             "fk_to_b",
								Type:             models.ForeignKeyConstraint,
								ReferencedTable:  "table_b",
								ReferencedSchema: "public",
							},
						},
					},
					{
						Name: "table_b",
						Columns: map[string]*models.Column{
							"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
						},
						Constraints: map[string]*models.Constraint{
							"fk_to_a": {
								Name:             "fk_to_a",
								Type:             models.ForeignKeyConstraint,
								ReferencedTable:  "table_a",
								ReferencedSchema: "public",
							},
						},
					},
				},
			},
		},
	}

	rule := Rule{
		Message: "Circular dependency detected",
	}

	results := validateCircularDependency(db, rule, "test_rule")

	// Should detect circular dependency in both tables
	if len(results) == 0 {
		t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
	}

	for _, result := range results {
		if result.Passed {
			t.Error("validateCircularDependency() passed=true, want false for circular dependency")
		}
	}
}
|
||||
|
||||
func TestNormalizeDataType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"varchar(50)", "varchar"},
|
||||
{"decimal(10,2)", "decimal"},
|
||||
{"int", "int"},
|
||||
{"BIGINT", "bigint"},
|
||||
{"VARCHAR(255)", "varchar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := normalizeDataType(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
value string
|
||||
expected bool
|
||||
}{
|
||||
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
|
||||
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
|
||||
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
|
||||
{"empty slice", []string{}, "foo", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.value)
|
||||
if result != tt.expected {
|
||||
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasCycle checks DFS cycle detection on small adjacency-list graphs:
// a three-node cycle, an acyclic chain, and a self-loop. Fresh visited and
// recursion-stack maps are created per case since hasCycle mutates them.
func TestHasCycle(t *testing.T) {
	tests := []struct {
		name     string
		graph    map[string][]string
		node     string
		expected bool
	}{
		{
			name: "simple cycle",
			graph: map[string][]string{
				"A": {"B"},
				"B": {"C"},
				"C": {"A"},
			},
			node:     "A",
			expected: true,
		},
		{
			name: "no cycle",
			graph: map[string][]string{
				"A": {"B"},
				"B": {"C"},
				"C": {},
			},
			node:     "A",
			expected: false,
		},
		{
			name: "self cycle",
			graph: map[string][]string{
				"A": {"A"},
			},
			node:     "A",
			expected: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			visited := make(map[string]bool)
			recStack := make(map[string]bool)
			result := hasCycle(tt.node, tt.graph, visited, recStack)
			if result != tt.expected {
				t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
			}
		})
	}
}
|
||||
|
||||
func TestFormatLocation(t *testing.T) {
|
||||
tests := []struct {
|
||||
schema string
|
||||
table string
|
||||
column string
|
||||
expected string
|
||||
}{
|
||||
{"public", "users", "id", "public.users.id"},
|
||||
{"public", "users", "", "public.users"},
|
||||
{"public", "", "", "public"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := formatLocation(tt.schema, tt.table, tt.column)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
|
||||
tt.schema, tt.table, tt.column, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
627
pkg/merge/merge.go
Normal file
627
pkg/merge/merge.go
Normal file
@@ -0,0 +1,627 @@
|
||||
// Package merge provides utilities for merging database schemas.
|
||||
// It allows combining schemas from multiple sources while avoiding duplicates,
|
||||
// supporting only additive operations (no deletion or modification of existing items).
|
||||
package merge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// MergeResult represents the result of a merge operation.
// Each counter records how many objects of that kind were added to the
// target; since merging is additive-only, nothing is ever removed or updated.
type MergeResult struct {
	SchemasAdded     int
	TablesAdded      int
	ColumnsAdded     int // includes columns of newly added tables
	ConstraintsAdded int
	IndexesAdded     int
	RelationsAdded   int
	DomainsAdded     int
	EnumsAdded       int
	ViewsAdded       int
	SequencesAdded   int
}
|
||||
|
||||
// MergeOptions contains options for merge operations.
// Skip* flags disable merging of the corresponding object kind entirely.
type MergeOptions struct {
	SkipDomains   bool
	SkipRelations bool
	SkipEnums     bool
	SkipViews     bool
	SkipSequences bool
	// SkipTableNames lists tables to skip during merge, keyed by table name.
	// Keys must be lowercase: mergeTables lowercases the candidate name
	// before lookup.
	SkipTableNames map[string]bool
}
|
||||
|
||||
// MergeDatabases merges the source database into the target database.
|
||||
// Only adds missing items; existing items are not modified.
|
||||
func MergeDatabases(target, source *models.Database, opts *MergeOptions) *MergeResult {
|
||||
if opts == nil {
|
||||
opts = &MergeOptions{}
|
||||
}
|
||||
|
||||
result := &MergeResult{}
|
||||
|
||||
if target == nil || source == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
// Merge schemas and their contents
|
||||
result.merge(target, source, opts)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// merge performs the database-level merge: each source schema is either merged
// into its existing counterpart (matched by SQLName) or cloned and appended,
// then database-level domains are merged unless opts.SkipDomains is set.
func (r *MergeResult) merge(target, source *models.Database, opts *MergeOptions) {
	// Create maps of existing schemas for quick lookup
	existingSchemas := make(map[string]*models.Schema)
	for _, schema := range target.Schemas {
		existingSchemas[schema.SQLName()] = schema
	}

	// Merge schemas
	for _, srcSchema := range source.Schemas {
		schemaName := srcSchema.SQLName()
		if tgtSchema, exists := existingSchemas[schemaName]; exists {
			// Schema exists, merge its contents
			r.mergeSchemaContents(tgtSchema, srcSchema, opts)
		} else {
			// Schema doesn't exist, add it (deep copy so target owns its data)
			newSchema := cloneSchema(srcSchema)
			target.Schemas = append(target.Schemas, newSchema)
			r.SchemasAdded++
		}
	}

	// Merge domains if not skipped
	if !opts.SkipDomains {
		r.mergeDomains(target, source)
	}
}
|
||||
|
||||
// mergeSchemaContents merges the contents of one schema into another:
// tables always, then views/sequences/enums/relations unless the
// corresponding Skip* option is set.
func (r *MergeResult) mergeSchemaContents(target, source *models.Schema, opts *MergeOptions) {
	// Merge tables
	r.mergeTables(target, source, opts)

	// Merge views if not skipped
	if !opts.SkipViews {
		r.mergeViews(target, source)
	}

	// Merge sequences if not skipped
	if !opts.SkipSequences {
		r.mergeSequences(target, source)
	}

	// Merge enums if not skipped
	if !opts.SkipEnums {
		r.mergeEnums(target, source)
	}

	// Merge relations if not skipped
	if !opts.SkipRelations {
		r.mergeRelations(target, source)
	}
}
|
||||
|
||||
// mergeTables merges source tables into schema. Tables already present
// (matched by SQLName) get their columns, constraints, and indexes merged;
// missing tables are cloned and appended. Tables listed in
// opts.SkipTableNames are ignored.
func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema, opts *MergeOptions) {
	// Create map of existing tables
	existingTables := make(map[string]*models.Table)
	for _, table := range schema.Tables {
		existingTables[table.SQLName()] = table
	}

	// Merge tables
	for _, srcTable := range source.Tables {
		tableName := srcTable.SQLName()

		// Skip if table is in the skip list (case-insensitive)
		// NOTE: only the candidate name is lowercased here, so the
		// SkipTableNames keys are expected to already be lowercase.
		if opts != nil && opts.SkipTableNames != nil && opts.SkipTableNames[strings.ToLower(tableName)] {
			continue
		}

		if tgtTable, exists := existingTables[tableName]; exists {
			// Table exists, merge its columns, constraints, and indexes
			r.mergeColumns(tgtTable, srcTable)
			r.mergeConstraints(tgtTable, srcTable)
			r.mergeIndexes(tgtTable, srcTable)
		} else {
			// Table doesn't exist, add it
			newTable := cloneTable(srcTable)
			schema.Tables = append(schema.Tables, newTable)
			r.TablesAdded++
			// Count columns in the newly added table
			r.ColumnsAdded += len(newTable.Columns)
		}
	}
}
|
||||
|
||||
func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table) {
|
||||
// Create map of existing columns
|
||||
existingColumns := make(map[string]*models.Column)
|
||||
for colName := range table.Columns {
|
||||
existingColumns[colName] = table.Columns[colName]
|
||||
}
|
||||
|
||||
// Merge columns
|
||||
for colName, srcCol := range srcTable.Columns {
|
||||
if _, exists := existingColumns[colName]; !exists {
|
||||
// Column doesn't exist, add it
|
||||
newCol := cloneColumn(srcCol)
|
||||
table.Columns[colName] = newCol
|
||||
r.ColumnsAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize constraints map if nil
|
||||
if table.Constraints == nil {
|
||||
table.Constraints = make(map[string]*models.Constraint)
|
||||
}
|
||||
|
||||
// Create map of existing constraints
|
||||
existingConstraints := make(map[string]*models.Constraint)
|
||||
for constName := range table.Constraints {
|
||||
existingConstraints[constName] = table.Constraints[constName]
|
||||
}
|
||||
|
||||
// Merge constraints
|
||||
for constName, srcConst := range srcTable.Constraints {
|
||||
if _, exists := existingConstraints[constName]; !exists {
|
||||
// Constraint doesn't exist, add it
|
||||
newConst := cloneConstraint(srcConst)
|
||||
table.Constraints[constName] = newConst
|
||||
r.ConstraintsAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize indexes map if nil
|
||||
if table.Indexes == nil {
|
||||
table.Indexes = make(map[string]*models.Index)
|
||||
}
|
||||
|
||||
// Create map of existing indexes
|
||||
existingIndexes := make(map[string]*models.Index)
|
||||
for idxName := range table.Indexes {
|
||||
existingIndexes[idxName] = table.Indexes[idxName]
|
||||
}
|
||||
|
||||
// Merge indexes
|
||||
for idxName, srcIdx := range srcTable.Indexes {
|
||||
if _, exists := existingIndexes[idxName]; !exists {
|
||||
// Index doesn't exist, add it
|
||||
newIdx := cloneIndex(srcIdx)
|
||||
table.Indexes[idxName] = newIdx
|
||||
r.IndexesAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mergeViews appends clones of source views whose SQLName is not already
// present in schema. Existing views are never modified.
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
	// Create map of existing views
	existingViews := make(map[string]*models.View)
	for _, view := range schema.Views {
		existingViews[view.SQLName()] = view
	}

	// Merge views
	for _, srcView := range source.Views {
		viewName := srcView.SQLName()
		if _, exists := existingViews[viewName]; !exists {
			// View doesn't exist, add it
			newView := cloneView(srcView)
			schema.Views = append(schema.Views, newView)
			r.ViewsAdded++
		}
	}
}
|
||||
|
||||
// mergeSequences appends clones of source sequences whose SQLName is not
// already present in schema. Existing sequences are never modified.
func (r *MergeResult) mergeSequences(schema *models.Schema, source *models.Schema) {
	// Create map of existing sequences
	existingSequences := make(map[string]*models.Sequence)
	for _, seq := range schema.Sequences {
		existingSequences[seq.SQLName()] = seq
	}

	// Merge sequences
	for _, srcSeq := range source.Sequences {
		seqName := srcSeq.SQLName()
		if _, exists := existingSequences[seqName]; !exists {
			// Sequence doesn't exist, add it
			newSeq := cloneSequence(srcSeq)
			schema.Sequences = append(schema.Sequences, newSeq)
			r.SequencesAdded++
		}
	}
}
|
||||
|
||||
// mergeEnums appends clones of source enums whose SQLName is not already
// present in schema. Existing enums are never modified.
func (r *MergeResult) mergeEnums(schema *models.Schema, source *models.Schema) {
	// Create map of existing enums
	existingEnums := make(map[string]*models.Enum)
	for _, enum := range schema.Enums {
		existingEnums[enum.SQLName()] = enum
	}

	// Merge enums
	for _, srcEnum := range source.Enums {
		enumName := srcEnum.SQLName()
		if _, exists := existingEnums[enumName]; !exists {
			// Enum doesn't exist, add it
			newEnum := cloneEnum(srcEnum)
			schema.Enums = append(schema.Enums, newEnum)
			r.EnumsAdded++
		}
	}
}
||||
|
||||
func (r *MergeResult) mergeRelations(schema *models.Schema, source *models.Schema) {
|
||||
// Create map of existing relations
|
||||
existingRelations := make(map[string]*models.Relationship)
|
||||
for _, rel := range schema.Relations {
|
||||
existingRelations[rel.SQLName()] = rel
|
||||
}
|
||||
|
||||
// Merge relations
|
||||
for _, srcRel := range source.Relations {
|
||||
if _, exists := existingRelations[srcRel.SQLName()]; !exists {
|
||||
// Relation doesn't exist, add it
|
||||
newRel := cloneRelation(srcRel)
|
||||
schema.Relations = append(schema.Relations, newRel)
|
||||
r.RelationsAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeDomains(target *models.Database, source *models.Database) {
|
||||
// Create map of existing domains
|
||||
existingDomains := make(map[string]*models.Domain)
|
||||
for _, domain := range target.Domains {
|
||||
existingDomains[domain.SQLName()] = domain
|
||||
}
|
||||
|
||||
// Merge domains
|
||||
for _, srcDomain := range source.Domains {
|
||||
domainName := srcDomain.SQLName()
|
||||
if _, exists := existingDomains[domainName]; !exists {
|
||||
// Domain doesn't exist, add it
|
||||
newDomain := cloneDomain(srcDomain)
|
||||
target.Domains = append(target.Domains, newDomain)
|
||||
r.DomainsAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone functions to create deep copies of models
|
||||
|
||||
// cloneSchema returns a deep copy of schema, or nil when schema is nil.
// Scalar fields are copied directly; Permissions and Metadata maps are
// copied key-by-key; tables, views, sequences, enums, and relations are
// cloned recursively via their respective clone helpers.
//
// NOTE(review): Scripts is only a shallow copy — the clone shares the
// *Script pointers with the source. Metadata values are also shared when
// they are reference types. Confirm whether a deep copy is intended there.
func cloneSchema(schema *models.Schema) *models.Schema {
	if schema == nil {
		return nil
	}
	// Start with scalar fields and empty collections; collections are
	// filled with cloned elements below.
	newSchema := &models.Schema{
		Name:        schema.Name,
		Description: schema.Description,
		Owner:       schema.Owner,
		Comment:     schema.Comment,
		Sequence:    schema.Sequence,
		UpdatedAt:   schema.UpdatedAt,
		Tables:      make([]*models.Table, 0),
		Views:       make([]*models.View, 0),
		Sequences:   make([]*models.Sequence, 0),
		Enums:       make([]*models.Enum, 0),
		Relations:   make([]*models.Relationship, 0),
	}

	// Copy the permissions map entry-by-entry (string values, so this is deep).
	if schema.Permissions != nil {
		newSchema.Permissions = make(map[string]string)
		for k, v := range schema.Permissions {
			newSchema.Permissions[k] = v
		}
	}

	// Copy the metadata map; values are copied by assignment only.
	if schema.Metadata != nil {
		newSchema.Metadata = make(map[string]interface{})
		for k, v := range schema.Metadata {
			newSchema.Metadata[k] = v
		}
	}

	// Shallow copy: the new slice shares *Script pointers with the source.
	if schema.Scripts != nil {
		newSchema.Scripts = make([]*models.Script, len(schema.Scripts))
		copy(newSchema.Scripts, schema.Scripts)
	}

	// Clone tables
	for _, table := range schema.Tables {
		newSchema.Tables = append(newSchema.Tables, cloneTable(table))
	}

	// Clone views
	for _, view := range schema.Views {
		newSchema.Views = append(newSchema.Views, cloneView(view))
	}

	// Clone sequences
	for _, seq := range schema.Sequences {
		newSchema.Sequences = append(newSchema.Sequences, cloneSequence(seq))
	}

	// Clone enums
	for _, enum := range schema.Enums {
		newSchema.Enums = append(newSchema.Enums, cloneEnum(enum))
	}

	// Clone relations
	for _, rel := range schema.Relations {
		newSchema.Relations = append(newSchema.Relations, cloneRelation(rel))
	}

	return newSchema
}
|
||||
|
||||
// cloneTable returns a deep copy of table, or nil when table is nil.
// Scalar fields are copied directly; the Metadata map is copied
// key-by-key (values shared by assignment); Columns, Constraints, and
// Indexes are cloned element-by-element via their clone helpers.
// Note: RefSchema is intentionally not copied, so the clone has no
// back-reference until it is reattached by the caller.
func cloneTable(table *models.Table) *models.Table {
	if table == nil {
		return nil
	}
	// Scalar fields plus freshly allocated (empty) collection maps.
	newTable := &models.Table{
		Name:        table.Name,
		Description: table.Description,
		Schema:      table.Schema,
		Comment:     table.Comment,
		Sequence:    table.Sequence,
		UpdatedAt:   table.UpdatedAt,
		Columns:     make(map[string]*models.Column),
		Constraints: make(map[string]*models.Constraint),
		Indexes:     make(map[string]*models.Index),
	}

	// Copy the metadata map; values are copied by assignment only.
	if table.Metadata != nil {
		newTable.Metadata = make(map[string]interface{})
		for k, v := range table.Metadata {
			newTable.Metadata[k] = v
		}
	}

	// Clone columns
	for colName, col := range table.Columns {
		newTable.Columns[colName] = cloneColumn(col)
	}

	// Clone constraints
	for constName, constraint := range table.Constraints {
		newTable.Constraints[constName] = cloneConstraint(constraint)
	}

	// Clone indexes
	for idxName, index := range table.Indexes {
		newTable.Indexes[idxName] = cloneIndex(index)
	}

	return newTable
}
|
||||
|
||||
func cloneColumn(col *models.Column) *models.Column {
|
||||
if col == nil {
|
||||
return nil
|
||||
}
|
||||
newCol := &models.Column{
|
||||
Name: col.Name,
|
||||
Type: col.Type,
|
||||
Description: col.Description,
|
||||
Comment: col.Comment,
|
||||
IsPrimaryKey: col.IsPrimaryKey,
|
||||
NotNull: col.NotNull,
|
||||
Default: col.Default,
|
||||
Precision: col.Precision,
|
||||
Scale: col.Scale,
|
||||
Length: col.Length,
|
||||
Sequence: col.Sequence,
|
||||
AutoIncrement: col.AutoIncrement,
|
||||
Collation: col.Collation,
|
||||
}
|
||||
|
||||
return newCol
|
||||
}
|
||||
|
||||
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
|
||||
if constraint == nil {
|
||||
return nil
|
||||
}
|
||||
newConstraint := &models.Constraint{
|
||||
Type: constraint.Type,
|
||||
Columns: make([]string, len(constraint.Columns)),
|
||||
ReferencedTable: constraint.ReferencedTable,
|
||||
ReferencedSchema: constraint.ReferencedSchema,
|
||||
ReferencedColumns: make([]string, len(constraint.ReferencedColumns)),
|
||||
OnUpdate: constraint.OnUpdate,
|
||||
OnDelete: constraint.OnDelete,
|
||||
Expression: constraint.Expression,
|
||||
Name: constraint.Name,
|
||||
Deferrable: constraint.Deferrable,
|
||||
InitiallyDeferred: constraint.InitiallyDeferred,
|
||||
Sequence: constraint.Sequence,
|
||||
}
|
||||
copy(newConstraint.Columns, constraint.Columns)
|
||||
copy(newConstraint.ReferencedColumns, constraint.ReferencedColumns)
|
||||
return newConstraint
|
||||
}
|
||||
|
||||
func cloneIndex(index *models.Index) *models.Index {
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
newIndex := &models.Index{
|
||||
Name: index.Name,
|
||||
Description: index.Description,
|
||||
Table: index.Table,
|
||||
Schema: index.Schema,
|
||||
Columns: make([]string, len(index.Columns)),
|
||||
Unique: index.Unique,
|
||||
Type: index.Type,
|
||||
Where: index.Where,
|
||||
Concurrent: index.Concurrent,
|
||||
Include: make([]string, len(index.Include)),
|
||||
Comment: index.Comment,
|
||||
Sequence: index.Sequence,
|
||||
}
|
||||
copy(newIndex.Columns, index.Columns)
|
||||
copy(newIndex.Include, index.Include)
|
||||
return newIndex
|
||||
}
|
||||
|
||||
func cloneView(view *models.View) *models.View {
|
||||
if view == nil {
|
||||
return nil
|
||||
}
|
||||
newView := &models.View{
|
||||
Name: view.Name,
|
||||
Description: view.Description,
|
||||
Schema: view.Schema,
|
||||
Definition: view.Definition,
|
||||
Comment: view.Comment,
|
||||
Sequence: view.Sequence,
|
||||
Columns: make(map[string]*models.Column),
|
||||
}
|
||||
|
||||
if view.Metadata != nil {
|
||||
newView.Metadata = make(map[string]interface{})
|
||||
for k, v := range view.Metadata {
|
||||
newView.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Clone columns
|
||||
for colName, col := range view.Columns {
|
||||
newView.Columns[colName] = cloneColumn(col)
|
||||
}
|
||||
|
||||
return newView
|
||||
}
|
||||
|
||||
func cloneSequence(seq *models.Sequence) *models.Sequence {
|
||||
if seq == nil {
|
||||
return nil
|
||||
}
|
||||
newSeq := &models.Sequence{
|
||||
Name: seq.Name,
|
||||
Description: seq.Description,
|
||||
Schema: seq.Schema,
|
||||
StartValue: seq.StartValue,
|
||||
MinValue: seq.MinValue,
|
||||
MaxValue: seq.MaxValue,
|
||||
IncrementBy: seq.IncrementBy,
|
||||
CacheSize: seq.CacheSize,
|
||||
Cycle: seq.Cycle,
|
||||
OwnedByTable: seq.OwnedByTable,
|
||||
OwnedByColumn: seq.OwnedByColumn,
|
||||
Comment: seq.Comment,
|
||||
Sequence: seq.Sequence,
|
||||
}
|
||||
return newSeq
|
||||
}
|
||||
|
||||
func cloneEnum(enum *models.Enum) *models.Enum {
|
||||
if enum == nil {
|
||||
return nil
|
||||
}
|
||||
newEnum := &models.Enum{
|
||||
Name: enum.Name,
|
||||
Values: make([]string, len(enum.Values)),
|
||||
Schema: enum.Schema,
|
||||
}
|
||||
copy(newEnum.Values, enum.Values)
|
||||
return newEnum
|
||||
}
|
||||
|
||||
// cloneRelation returns a deep copy of rel, or nil when rel is nil.
// Scalar fields are copied directly; FromColumns/ToColumns slices are
// duplicated; the Properties map is copied key-by-key (string values, so
// the copy is deep).
func cloneRelation(rel *models.Relationship) *models.Relationship {
	if rel == nil {
		return nil
	}
	// Allocate the column slices at the source lengths; they are filled
	// via copy() below.
	newRel := &models.Relationship{
		Name:          rel.Name,
		Type:          rel.Type,
		FromTable:     rel.FromTable,
		FromSchema:    rel.FromSchema,
		FromColumns:   make([]string, len(rel.FromColumns)),
		ToTable:       rel.ToTable,
		ToSchema:      rel.ToSchema,
		ToColumns:     make([]string, len(rel.ToColumns)),
		ForeignKey:    rel.ForeignKey,
		ThroughTable:  rel.ThroughTable,
		ThroughSchema: rel.ThroughSchema,
		Description:   rel.Description,
		Sequence:      rel.Sequence,
	}

	// Copy the properties map entry-by-entry.
	if rel.Properties != nil {
		newRel.Properties = make(map[string]string)
		for k, v := range rel.Properties {
			newRel.Properties[k] = v
		}
	}

	copy(newRel.FromColumns, rel.FromColumns)
	copy(newRel.ToColumns, rel.ToColumns)
	return newRel
}
|
||||
|
||||
func cloneDomain(domain *models.Domain) *models.Domain {
|
||||
if domain == nil {
|
||||
return nil
|
||||
}
|
||||
newDomain := &models.Domain{
|
||||
Name: domain.Name,
|
||||
Description: domain.Description,
|
||||
Comment: domain.Comment,
|
||||
Sequence: domain.Sequence,
|
||||
Tables: make([]*models.DomainTable, len(domain.Tables)),
|
||||
}
|
||||
|
||||
if domain.Metadata != nil {
|
||||
newDomain.Metadata = make(map[string]interface{})
|
||||
for k, v := range domain.Metadata {
|
||||
newDomain.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
copy(newDomain.Tables, domain.Tables)
|
||||
return newDomain
|
||||
}
|
||||
|
||||
// GetMergeSummary returns a human-readable summary of the merge result
|
||||
func GetMergeSummary(result *MergeResult) string {
|
||||
if result == nil {
|
||||
return "No merge result available"
|
||||
}
|
||||
|
||||
lines := []string{
|
||||
"=== Merge Summary ===",
|
||||
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
||||
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
||||
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
||||
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
|
||||
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
|
||||
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
||||
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
|
||||
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
|
||||
}
|
||||
|
||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||
result.ConstraintsAdded + result.IndexesAdded +
|
||||
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
||||
result.RelationsAdded + result.DomainsAdded
|
||||
|
||||
lines = append(lines, fmt.Sprintf("Total items added: %d", totalAdded))
|
||||
|
||||
summary := ""
|
||||
for _, line := range lines {
|
||||
summary += line + "\n"
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
617
pkg/merge/merge_test.go
Normal file
617
pkg/merge/merge_test.go
Normal file
@@ -0,0 +1,617 @@
|
||||
package merge
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// TestMergeDatabases_NilInputs verifies that merging nil databases still
// returns a non-nil result with zero counts.
func TestMergeDatabases_NilInputs(t *testing.T) {
	result := MergeDatabases(nil, nil, nil)
	if result == nil {
		t.Fatal("Expected non-nil result")
	}
	if result.SchemasAdded != 0 {
		t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
	}
}
|
||||
|
||||
// TestMergeDatabases_NewSchema verifies that a schema present only in the
// source is added to the target and counted.
func TestMergeDatabases_NewSchema(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{Name: "public"},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{Name: "auth"},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.SchemasAdded != 1 {
		t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
	}
	if len(target.Schemas) != 2 {
		t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
	}
}
|
||||
|
||||
// TestMergeDatabases_ExistingSchema verifies that a schema already present
// in the target is not duplicated or counted.
func TestMergeDatabases_ExistingSchema(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{Name: "public"},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{Name: "public"},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.SchemasAdded != 0 {
		t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
	}
	if len(target.Schemas) != 1 {
		t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
	}
}
|
||||
|
||||
// TestMergeTables_NewTable verifies that a table present only in the source
// schema is added to the matching target schema.
func TestMergeTables_NewTable(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "posts",
						Schema:  "public",
						Columns: map[string]*models.Column{},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.TablesAdded != 1 {
		t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
	}
	if len(target.Schemas[0].Tables) != 2 {
		t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
	}
}
|
||||
|
||||
// TestMergeColumns_NewColumn verifies that a column present only in the
// source table is merged into the existing target table.
func TestMergeColumns_NewColumn(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:   "users",
						Schema: "public",
						Columns: map[string]*models.Column{
							"id": {Name: "id", Type: "int"},
						},
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:   "users",
						Schema: "public",
						Columns: map[string]*models.Column{
							"email": {Name: "email", Type: "varchar"},
						},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.ColumnsAdded != 1 {
		t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
	}
	if len(target.Schemas[0].Tables[0].Columns) != 2 {
		t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
	}
}
|
||||
|
||||
// TestMergeConstraints_NewConstraint verifies that a constraint present only
// in the source table is merged into the existing target table.
func TestMergeConstraints_NewConstraint(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:        "users",
						Schema:      "public",
						Columns:     map[string]*models.Column{},
						Constraints: map[string]*models.Constraint{},
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
						Constraints: map[string]*models.Constraint{
							"ukey_users_email": {
								Type:    models.UniqueConstraint,
								Columns: []string{"email"},
								Name:    "ukey_users_email",
							},
						},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.ConstraintsAdded != 1 {
		t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
	}
	if len(target.Schemas[0].Tables[0].Constraints) != 1 {
		t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
	}
}
|
||||
|
||||
// TestMergeConstraints_NilConstraintsMap verifies that merging constraints
// into a table whose Constraints map is nil initializes the map first.
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:        "users",
						Schema:      "public",
						Columns:     map[string]*models.Column{},
						Constraints: nil, // Nil map
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
						Constraints: map[string]*models.Constraint{
							"ukey_users_email": {
								Type:    models.UniqueConstraint,
								Columns: []string{"email"},
								Name:    "ukey_users_email",
							},
						},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.ConstraintsAdded != 1 {
		t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
	}
	if target.Schemas[0].Tables[0].Constraints == nil {
		t.Error("Expected constraints map to be initialized")
	}
	if len(target.Schemas[0].Tables[0].Constraints) != 1 {
		t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
	}
}
|
||||
|
||||
// TestMergeIndexes_NewIndex verifies that an index present only in the
// source table is merged into the existing target table.
func TestMergeIndexes_NewIndex(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
						Indexes: map[string]*models.Index{},
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
						Indexes: map[string]*models.Index{
							"idx_users_email": {
								Name:    "idx_users_email",
								Columns: []string{"email"},
							},
						},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.IndexesAdded != 1 {
		t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
	}
	if len(target.Schemas[0].Tables[0].Indexes) != 1 {
		t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
	}
}
|
||||
|
||||
// TestMergeIndexes_NilIndexesMap verifies that merging indexes into a table
// whose Indexes map is nil initializes the map first.
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
						Indexes: nil, // Nil map
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
						Indexes: map[string]*models.Index{
							"idx_users_email": {
								Name:    "idx_users_email",
								Columns: []string{"email"},
							},
						},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.IndexesAdded != 1 {
		t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
	}
	if target.Schemas[0].Tables[0].Indexes == nil {
		t.Error("Expected indexes map to be initialized")
	}
	if len(target.Schemas[0].Tables[0].Indexes) != 1 {
		t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
	}
}
|
||||
|
||||
// TestMergeOptions_SkipTableNames verifies that tables listed in
// MergeOptions.SkipTableNames are excluded from the merge.
func TestMergeOptions_SkipTableNames(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "users",
						Schema:  "public",
						Columns: map[string]*models.Column{},
					},
				},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:    "migrations",
						Schema:  "public",
						Columns: map[string]*models.Column{},
					},
				},
			},
		},
	}

	opts := &MergeOptions{
		SkipTableNames: map[string]bool{
			"migrations": true,
		},
	}

	result := MergeDatabases(target, source, opts)
	if result.TablesAdded != 0 {
		t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
	}
	if len(target.Schemas[0].Tables) != 1 {
		t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
	}
}
|
||||
|
||||
// TestMergeViews_NewView verifies that a view present only in the source
// schema is added to the matching target schema.
func TestMergeViews_NewView(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name:  "public",
				Views: []*models.View{},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Views: []*models.View{
					{
						Name:       "user_summary",
						Schema:     "public",
						Definition: "SELECT * FROM users",
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.ViewsAdded != 1 {
		t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
	}
	if len(target.Schemas[0].Views) != 1 {
		t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
	}
}
|
||||
|
||||
// TestMergeEnums_NewEnum verifies that an enum present only in the source
// schema is added to the matching target schema.
func TestMergeEnums_NewEnum(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name:  "public",
				Enums: []*models.Enum{},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Enums: []*models.Enum{
					{
						Name:   "user_role",
						Schema: "public",
						Values: []string{"admin", "user"},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.EnumsAdded != 1 {
		t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
	}
	if len(target.Schemas[0].Enums) != 1 {
		t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
	}
}
|
||||
|
||||
// TestMergeDomains_NewDomain verifies that a database-level domain present
// only in the source is added to the target.
func TestMergeDomains_NewDomain(t *testing.T) {
	target := &models.Database{
		Domains: []*models.Domain{},
	}
	source := &models.Database{
		Domains: []*models.Domain{
			{
				Name:        "auth",
				Description: "Authentication domain",
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.DomainsAdded != 1 {
		t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
	}
	if len(target.Domains) != 1 {
		t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
	}
}
|
||||
|
||||
// TestMergeRelations_NewRelation verifies that a relationship present only
// in the source schema is added to the matching target schema.
func TestMergeRelations_NewRelation(t *testing.T) {
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name:      "public",
				Relations: []*models.Relationship{},
			},
		},
	}
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Relations: []*models.Relationship{
					{
						Name:        "fk_posts_user",
						Type:        models.OneToMany,
						FromTable:   "posts",
						FromColumns: []string{"user_id"},
						ToTable:     "users",
						ToColumns:   []string{"id"},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)
	if result.RelationsAdded != 1 {
		t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
	}
	if len(target.Schemas[0].Relations) != 1 {
		t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
	}
}
|
||||
|
||||
// TestGetMergeSummary verifies that a populated result renders a non-trivial
// summary string. (Only length is checked, not exact content.)
func TestGetMergeSummary(t *testing.T) {
	result := &MergeResult{
		SchemasAdded:     1,
		TablesAdded:      2,
		ColumnsAdded:     5,
		ConstraintsAdded: 3,
		IndexesAdded:     2,
		ViewsAdded:       1,
	}

	summary := GetMergeSummary(result)
	if summary == "" {
		t.Error("Expected non-empty summary")
	}
	if len(summary) < 50 {
		t.Errorf("Summary seems too short: %s", summary)
	}
}
|
||||
|
||||
// TestGetMergeSummary_Nil verifies that a nil result still produces a
// non-empty placeholder summary.
func TestGetMergeSummary_Nil(t *testing.T) {
	summary := GetMergeSummary(nil)
	if summary == "" {
		t.Error("Expected non-empty summary for nil result")
	}
}
|
||||
|
||||
// TestComplexMerge exercises a merge into an existing table that adds
// columns, constraints, and indexes in one pass, verifying both the
// reported counts and the final state of the target table.
func TestComplexMerge(t *testing.T) {
	// Target with existing structure
	target := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:   "users",
						Schema: "public",
						Columns: map[string]*models.Column{
							"id": {Name: "id", Type: "int"},
						},
						Constraints: map[string]*models.Constraint{},
						Indexes:     map[string]*models.Index{},
					},
				},
			},
		},
	}

	// Source with new columns, constraints, and indexes
	source := &models.Database{
		Schemas: []*models.Schema{
			{
				Name: "public",
				Tables: []*models.Table{
					{
						Name:   "users",
						Schema: "public",
						Columns: map[string]*models.Column{
							"email": {Name: "email", Type: "varchar"},
							"guid":  {Name: "guid", Type: "uuid"},
						},
						Constraints: map[string]*models.Constraint{
							"ukey_users_email": {
								Type:    models.UniqueConstraint,
								Columns: []string{"email"},
								Name:    "ukey_users_email",
							},
							"ukey_users_guid": {
								Type:    models.UniqueConstraint,
								Columns: []string{"guid"},
								Name:    "ukey_users_guid",
							},
						},
						Indexes: map[string]*models.Index{
							"idx_users_email": {
								Name:    "idx_users_email",
								Columns: []string{"email"},
							},
						},
					},
				},
			},
		},
	}

	result := MergeDatabases(target, source, nil)

	// Verify counts
	if result.ColumnsAdded != 2 {
		t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
	}
	if result.ConstraintsAdded != 2 {
		t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
	}
	if result.IndexesAdded != 1 {
		t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
	}

	// Verify target has merged data
	table := target.Schemas[0].Tables[0]
	if len(table.Columns) != 3 {
		t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
	}
	if len(table.Constraints) != 2 {
		t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
	}
	if len(table.Indexes) != 1 {
		t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
	}

	// Verify specific constraint
	if _, exists := table.Constraints["ukey_users_guid"]; !exists {
		t.Error("Expected ukey_users_guid constraint to exist")
	}
}
|
||||
@@ -4,7 +4,12 @@
|
||||
// intermediate representation for converting between various database schema formats.
|
||||
package models
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// DatabaseType represents the type of database system.
|
||||
type DatabaseType string
|
||||
@@ -21,10 +26,13 @@ type Database struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Description string `json:"description,omitempty" yaml:"description,omitempty" xml:"description,omitempty"`
|
||||
Schemas []*Schema `json:"schemas" yaml:"schemas" xml:"schemas"`
|
||||
Domains []*Domain `json:"domains,omitempty" yaml:"domains,omitempty" xml:"domains,omitempty"`
|
||||
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
|
||||
DatabaseType DatabaseType `json:"database_type,omitempty" yaml:"database_type,omitempty" xml:"database_type,omitempty"`
|
||||
DatabaseVersion string `json:"database_version,omitempty" yaml:"database_version,omitempty" xml:"database_version,omitempty"`
|
||||
SourceFormat string `json:"source_format,omitempty" yaml:"source_format,omitempty" xml:"source_format,omitempty"` // Source Format of the database.
|
||||
UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the database name in lowercase for SQL compatibility.
|
||||
@@ -32,6 +40,39 @@ func (d *Database) SQLName() string {
|
||||
return strings.ToLower(d.Name)
|
||||
}
|
||||
|
||||
// UpdateDate sets the UpdatedAt field to the current time in RFC3339 format.
|
||||
func (d *Database) UpdateDate() {
|
||||
d.UpdatedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Domain represents a logical business domain grouping multiple tables from potentially different schemas.
|
||||
// Domains allow for organizing database tables by functional areas (e.g., authentication, user data, financial).
|
||||
type Domain struct {
|
||||
Name string `json:"name" yaml:"name" xml:"name"`
|
||||
Description string `json:"description,omitempty" yaml:"description,omitempty" xml:"description,omitempty"`
|
||||
Tables []*DomainTable `json:"tables" yaml:"tables" xml:"tables"`
|
||||
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the domain name in lowercase for SQL compatibility.
|
||||
func (d *Domain) SQLName() string {
|
||||
return strings.ToLower(d.Name)
|
||||
}
|
||||
|
||||
// DomainTable represents a reference to a specific table within a domain.
|
||||
// It identifies the table by name and schema, allowing a single domain to include
|
||||
// tables from multiple schemas.
|
||||
type DomainTable struct {
|
||||
TableName string `json:"table_name" yaml:"table_name" xml:"table_name"`
|
||||
SchemaName string `json:"schema_name" yaml:"schema_name" xml:"schema_name"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
RefTable *Table `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// Schema represents a database schema, which is a logical grouping of database objects
|
||||
// such as tables, views, sequences, and relationships within a database.
|
||||
type Schema struct {
|
||||
@@ -49,6 +90,16 @@ type Schema struct {
|
||||
RefDatabase *Database `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
|
||||
Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"`
|
||||
Enums []*Enum `json:"enums,omitempty" yaml:"enums,omitempty" xml:"enums"`
|
||||
UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// UpdaUpdateDateted sets the UpdatedAt field to the current time in RFC3339 format.
|
||||
func (d *Schema) UpdateDate() {
|
||||
d.UpdatedAt = time.Now().Format(time.RFC3339)
|
||||
if d.RefDatabase != nil {
|
||||
d.RefDatabase.UpdateDate()
|
||||
}
|
||||
}
|
||||
|
||||
// SQLName returns the schema name in lowercase for SQL compatibility.
|
||||
@@ -71,6 +122,16 @@ type Table struct {
|
||||
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
|
||||
UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// UpdateDate sets the UpdatedAt field to the current time in RFC3339 format.
|
||||
func (d *Table) UpdateDate() {
|
||||
d.UpdatedAt = time.Now().Format(time.RFC3339)
|
||||
if d.RefSchema != nil {
|
||||
d.RefSchema.UpdateDate()
|
||||
}
|
||||
}
|
||||
|
||||
// SQLName returns the table name in lowercase for SQL compatibility.
|
||||
@@ -111,6 +172,7 @@ type View struct {
|
||||
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the view name in lowercase for SQL compatibility.
|
||||
@@ -134,6 +196,7 @@ type Sequence struct {
|
||||
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the sequence name in lowercase for SQL compatibility.
|
||||
@@ -158,6 +221,7 @@ type Column struct {
|
||||
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
|
||||
Collation string `json:"collation,omitempty" yaml:"collation,omitempty" xml:"collation,omitempty"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the column name in lowercase for SQL compatibility.
|
||||
@@ -180,6 +244,7 @@ type Index struct {
|
||||
Include []string `json:"include,omitempty" yaml:"include,omitempty" xml:"include,omitempty"` // INCLUDE columns
|
||||
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the index name in lowercase for SQL compatibility.
|
||||
@@ -214,6 +279,7 @@ type Relationship struct {
|
||||
ThroughSchema string `json:"through_schema,omitempty" yaml:"through_schema,omitempty" xml:"through_schema,omitempty"`
|
||||
Description string `json:"description,omitempty" yaml:"description,omitempty" xml:"description,omitempty"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the relationship name in lowercase for SQL compatibility.
|
||||
@@ -238,6 +304,7 @@ type Constraint struct {
|
||||
Deferrable bool `json:"deferrable,omitempty" yaml:"deferrable,omitempty" xml:"deferrable,omitempty"`
|
||||
InitiallyDeferred bool `json:"initially_deferred,omitempty" yaml:"initially_deferred,omitempty" xml:"initially_deferred,omitempty"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the constraint name in lowercase for SQL compatibility.
|
||||
@@ -253,6 +320,7 @@ type Enum struct {
|
||||
Name string `json:"name" yaml:"name" xml:"name"`
|
||||
Values []string `json:"values" yaml:"values" xml:"values"`
|
||||
Schema string `json:"schema,omitempty" yaml:"schema,omitempty" xml:"schema,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the enum name in lowercase for SQL compatibility.
|
||||
@@ -260,6 +328,16 @@ func (d *Enum) SQLName() string {
|
||||
return strings.ToLower(d.Name)
|
||||
}
|
||||
|
||||
// InitEnum initializes a new Enum with empty values slice
|
||||
func InitEnum(name, schema string) *Enum {
|
||||
return &Enum{
|
||||
Name: name,
|
||||
Schema: schema,
|
||||
Values: make([]string, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
// Supported constraint types.
|
||||
const (
|
||||
PrimaryKeyConstraint ConstraintType = "primary_key" // Primary key uniquely identifies each record
|
||||
@@ -281,6 +359,7 @@ type Script struct {
|
||||
Version string `json:"version,omitempty" yaml:"version,omitempty" xml:"version,omitempty"`
|
||||
Priority int `json:"priority,omitempty" yaml:"priority,omitempty" xml:"priority,omitempty"`
|
||||
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
|
||||
GUID string `json:"guid" yaml:"guid" xml:"guid"`
|
||||
}
|
||||
|
||||
// SQLName returns the script name in lowercase for SQL compatibility.
|
||||
@@ -295,6 +374,8 @@ func InitDatabase(name string) *Database {
|
||||
return &Database{
|
||||
Name: name,
|
||||
Schemas: make([]*Schema, 0),
|
||||
Domains: make([]*Domain, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,6 +389,7 @@ func InitSchema(name string) *Schema {
|
||||
Permissions: make(map[string]string),
|
||||
Metadata: make(map[string]any),
|
||||
Scripts: make([]*Script, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,6 +403,7 @@ func InitTable(name, schema string) *Table {
|
||||
Indexes: make(map[string]*Index),
|
||||
Relationships: make(map[string]*Relationship),
|
||||
Metadata: make(map[string]any),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,6 +413,7 @@ func InitColumn(name, table, schema string) *Column {
|
||||
Name: name,
|
||||
Table: table,
|
||||
Schema: schema,
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,6 +425,7 @@ func InitIndex(name, table, schema string) *Index {
|
||||
Schema: schema,
|
||||
Columns: make([]string, 0),
|
||||
Include: make([]string, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,6 +438,7 @@ func InitRelation(name, schema string) *Relationship {
|
||||
Properties: make(map[string]string),
|
||||
FromColumns: make([]string, 0),
|
||||
ToColumns: make([]string, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,6 +448,7 @@ func InitRelationship(name string, relType RelationType) *Relationship {
|
||||
Name: name,
|
||||
Type: relType,
|
||||
Properties: make(map[string]string),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,6 +459,7 @@ func InitConstraint(name string, constraintType ConstraintType) *Constraint {
|
||||
Type: constraintType,
|
||||
Columns: make([]string, 0),
|
||||
ReferencedColumns: make([]string, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,6 +468,7 @@ func InitScript(name string) *Script {
|
||||
return &Script{
|
||||
Name: name,
|
||||
RunAfter: make([]string, 0),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,6 +479,7 @@ func InitView(name, schema string) *View {
|
||||
Schema: schema,
|
||||
Columns: make(map[string]*Column),
|
||||
Metadata: make(map[string]any),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -400,5 +490,25 @@ func InitSequence(name, schema string) *Sequence {
|
||||
Schema: schema,
|
||||
IncrementBy: 1,
|
||||
StartValue: 1,
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
// InitDomain initializes a new Domain with empty slices and maps
|
||||
func InitDomain(name string) *Domain {
|
||||
return &Domain{
|
||||
Name: name,
|
||||
Tables: make([]*DomainTable, 0),
|
||||
Metadata: make(map[string]any),
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
// InitDomainTable initializes a new DomainTable reference
|
||||
func InitDomainTable(tableName, schemaName string) *DomainTable {
|
||||
return &DomainTable{
|
||||
TableName: tableName,
|
||||
SchemaName: schemaName,
|
||||
GUID: uuid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,31 +4,31 @@ import "strings"
|
||||
|
||||
var GoToStdTypes = map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "integer",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"int8": "integer",
|
||||
"int16": "integer",
|
||||
"int8": "smallint",
|
||||
"int16": "smallint",
|
||||
"int32": "integer",
|
||||
"uint": "integer",
|
||||
"uint8": "integer",
|
||||
"uint16": "integer",
|
||||
"uint8": "smallint",
|
||||
"uint16": "smallint",
|
||||
"uint32": "integer",
|
||||
"uint64": "integer",
|
||||
"uintptr": "integer",
|
||||
"znullint64": "integer",
|
||||
"uint64": "bigint",
|
||||
"uintptr": "bigint",
|
||||
"znullint64": "bigint",
|
||||
"znullint32": "integer",
|
||||
"znullbyte": "integer",
|
||||
"znullbyte": "smallint",
|
||||
"float64": "double",
|
||||
"float32": "double",
|
||||
"complex64": "double",
|
||||
"complex128": "double",
|
||||
"customfloat64": "double",
|
||||
"string": "string",
|
||||
"Pointer": "integer",
|
||||
"string": "text",
|
||||
"Pointer": "bigint",
|
||||
"[]byte": "blob",
|
||||
"customdate": "string",
|
||||
"customtime": "string",
|
||||
"customtimestamp": "string",
|
||||
"customdate": "date",
|
||||
"customtime": "time",
|
||||
"customtimestamp": "timestamp",
|
||||
"sqlfloat64": "double",
|
||||
"sqlfloat16": "double",
|
||||
"sqluuid": "uuid",
|
||||
@@ -36,9 +36,9 @@ var GoToStdTypes = map[string]string{
|
||||
"sqljson": "json",
|
||||
"sqlint64": "bigint",
|
||||
"sqlint32": "integer",
|
||||
"sqlint16": "integer",
|
||||
"sqlint16": "smallint",
|
||||
"sqlbool": "boolean",
|
||||
"sqlstring": "string",
|
||||
"sqlstring": "text",
|
||||
"nullablejsonb": "jsonb",
|
||||
"nullablejson": "json",
|
||||
"nullableuuid": "uuid",
|
||||
@@ -67,7 +67,7 @@ var GoToPGSQLTypes = map[string]string{
|
||||
"float32": "real",
|
||||
"complex64": "double precision",
|
||||
"complex128": "double precision",
|
||||
"customfloat64": "double precisio",
|
||||
"customfloat64": "double precision",
|
||||
"string": "text",
|
||||
"Pointer": "bigint",
|
||||
"[]byte": "bytea",
|
||||
@@ -81,9 +81,9 @@ var GoToPGSQLTypes = map[string]string{
|
||||
"sqljson": "json",
|
||||
"sqlint64": "bigint",
|
||||
"sqlint32": "integer",
|
||||
"sqlint16": "integer",
|
||||
"sqlint16": "smallint",
|
||||
"sqlbool": "boolean",
|
||||
"sqlstring": "string",
|
||||
"sqlstring": "text",
|
||||
"nullablejsonb": "jsonb",
|
||||
"nullablejson": "json",
|
||||
"nullableuuid": "uuid",
|
||||
|
||||
339
pkg/pgsql/datatypes_test.go
Normal file
339
pkg/pgsql/datatypes_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqltype string
|
||||
want bool
|
||||
}{
|
||||
// PostgreSQL types
|
||||
{"Valid PGSQL bigint", "bigint", true},
|
||||
{"Valid PGSQL integer", "integer", true},
|
||||
{"Valid PGSQL text", "text", true},
|
||||
{"Valid PGSQL boolean", "boolean", true},
|
||||
{"Valid PGSQL double precision", "double precision", true},
|
||||
{"Valid PGSQL bytea", "bytea", true},
|
||||
{"Valid PGSQL uuid", "uuid", true},
|
||||
{"Valid PGSQL jsonb", "jsonb", true},
|
||||
{"Valid PGSQL json", "json", true},
|
||||
{"Valid PGSQL timestamp", "timestamp", true},
|
||||
{"Valid PGSQL date", "date", true},
|
||||
{"Valid PGSQL time", "time", true},
|
||||
{"Valid PGSQL citext", "citext", true},
|
||||
|
||||
// Standard types
|
||||
{"Valid std double", "double", true},
|
||||
{"Valid std blob", "blob", true},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BIGINT", "BIGINT", true},
|
||||
{"Case insensitive TeXt", "TeXt", true},
|
||||
{"Case insensitive BoOlEaN", "BoOlEaN", true},
|
||||
|
||||
// Invalid types
|
||||
{"Invalid type", "invalidtype", false},
|
||||
{"Invalid type varchar", "varchar", false},
|
||||
{"Empty string", "", false},
|
||||
{"Random string", "foobar", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ValidSQLType(tt.sqltype)
|
||||
if got != tt.want {
|
||||
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anytype string
|
||||
want string
|
||||
}{
|
||||
// Go types to PostgreSQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double precision", "float64", "double precision"},
|
||||
{"Go float32 to real", "float32", "real"},
|
||||
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||
|
||||
// SQL types remain SQL types
|
||||
{"SQL bigint", "bigint", "bigint"},
|
||||
{"SQL integer", "integer", "integer"},
|
||||
{"SQL text", "text", "text"},
|
||||
{"SQL boolean", "boolean", "boolean"},
|
||||
{"SQL uuid", "uuid", "uuid"},
|
||||
{"SQL jsonb", "jsonb", "jsonb"},
|
||||
|
||||
// Case insensitive Go types
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
{"Case insensitive STRING", "STRING", "text"},
|
||||
|
||||
// Case insensitive SQL types
|
||||
{"Case insensitive BIGINT", "BIGINT", "bigint"},
|
||||
{"Case insensitive TEXT", "TEXT", "text"},
|
||||
|
||||
// Custom types
|
||||
{"Custom sqluuid", "sqluuid", "uuid"},
|
||||
{"Custom sqljsonb", "sqljsonb", "jsonb"},
|
||||
{"Custom sqlint64", "sqlint64", "bigint"},
|
||||
|
||||
// Unknown types default to text
|
||||
{"Unknown type varchar", "varchar", "text"},
|
||||
{"Unknown type foobar", "foobar", "text"},
|
||||
{"Empty string", "", "text"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetSQLType(tt.anytype)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anytype string
|
||||
want string
|
||||
}{
|
||||
// Go types to PostgreSQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double precision", "float64", "double precision"},
|
||||
{"Go float32 to real", "float32", "real"},
|
||||
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||
|
||||
// SQL types remain SQL types
|
||||
{"SQL bigint", "bigint", "bigint"},
|
||||
{"SQL integer", "integer", "integer"},
|
||||
{"SQL text", "text", "text"},
|
||||
{"SQL boolean", "boolean", "boolean"},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
|
||||
// Unknown types remain unchanged (difference from GetSQLType)
|
||||
{"Unknown type varchar", "varchar", "varchar"},
|
||||
{"Unknown type foobar", "foobar", "foobar"},
|
||||
{"Empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ConvertSQLType(tt.anytype)
|
||||
if got != tt.want {
|
||||
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGoType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
typeName string
|
||||
want bool
|
||||
}{
|
||||
// Go basic types
|
||||
{"Go bool", "bool", true},
|
||||
{"Go int64", "int64", true},
|
||||
{"Go int", "int", true},
|
||||
{"Go int32", "int32", true},
|
||||
{"Go int16", "int16", true},
|
||||
{"Go int8", "int8", true},
|
||||
{"Go uint", "uint", true},
|
||||
{"Go uint64", "uint64", true},
|
||||
{"Go uint32", "uint32", true},
|
||||
{"Go uint16", "uint16", true},
|
||||
{"Go uint8", "uint8", true},
|
||||
{"Go float64", "float64", true},
|
||||
{"Go float32", "float32", true},
|
||||
{"Go string", "string", true},
|
||||
{"Go []byte", "[]byte", true},
|
||||
|
||||
// Go custom types
|
||||
{"Go complex64", "complex64", true},
|
||||
{"Go complex128", "complex128", true},
|
||||
{"Go uintptr", "uintptr", true},
|
||||
{"Go Pointer", "Pointer", true},
|
||||
|
||||
// Custom SQL types
|
||||
{"Custom sqluuid", "sqluuid", true},
|
||||
{"Custom sqljsonb", "sqljsonb", true},
|
||||
{"Custom sqlint64", "sqlint64", true},
|
||||
{"Custom customdate", "customdate", true},
|
||||
{"Custom customtime", "customtime", true},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", true},
|
||||
{"Case insensitive InT64", "InT64", true},
|
||||
{"Case insensitive STRING", "STRING", true},
|
||||
|
||||
// SQL types (not Go types)
|
||||
{"SQL bigint", "bigint", false},
|
||||
{"SQL integer", "integer", false},
|
||||
{"SQL text", "text", false},
|
||||
{"SQL boolean", "boolean", false},
|
||||
|
||||
// Invalid types
|
||||
{"Invalid type", "invalidtype", false},
|
||||
{"Empty string", "", false},
|
||||
{"Random string", "foobar", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsGoType(tt.typeName)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStdTypeFromGo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
typeName string
|
||||
want string
|
||||
}{
|
||||
// Go types to standard SQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double", "float64", "double"},
|
||||
{"Go float32 to double", "float32", "double"},
|
||||
{"Go []byte to blob", "[]byte", "blob"},
|
||||
{"Go int32 to integer", "int32", "integer"},
|
||||
{"Go int16 to smallint", "int16", "smallint"},
|
||||
|
||||
// Custom types
|
||||
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
|
||||
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
|
||||
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
|
||||
{"Custom customdate to date", "customdate", "date"},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
{"Case insensitive STRING", "STRING", "text"},
|
||||
|
||||
// Non-Go types remain unchanged
|
||||
{"SQL bigint unchanged", "bigint", "bigint"},
|
||||
{"SQL integer unchanged", "integer", "integer"},
|
||||
{"Invalid type unchanged", "invalidtype", "invalidtype"},
|
||||
{"Empty string unchanged", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetStdTypeFromGo(tt.typeName)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoToStdTypesMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"string": "text",
|
||||
"float64": "double",
|
||||
"[]byte": "blob",
|
||||
}
|
||||
|
||||
for goType, expectedStd := range expectedMappings {
|
||||
if stdType, ok := GoToStdTypes[goType]; !ok {
|
||||
t.Errorf("GoToStdTypes missing entry for %q", goType)
|
||||
} else if stdType != expectedStd {
|
||||
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the map is not empty
|
||||
if len(GoToStdTypes) == 0 {
|
||||
t.Error("GoToStdTypes map is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoToPGSQLTypesMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"string": "text",
|
||||
"float64": "double precision",
|
||||
"float32": "real",
|
||||
"[]byte": "bytea",
|
||||
}
|
||||
|
||||
for goType, expectedPG := range expectedMappings {
|
||||
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
|
||||
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
|
||||
} else if pgType != expectedPG {
|
||||
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the map is not empty
|
||||
if len(GoToPGSQLTypes) == 0 {
|
||||
t.Error("GoToPGSQLTypes map is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeConversionConsistency(t *testing.T) {
|
||||
// Test that GetSQLType and ConvertSQLType are consistent for known types
|
||||
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
|
||||
|
||||
for _, goType := range knownGoTypes {
|
||||
getSQLResult := GetSQLType(goType)
|
||||
convertResult := ConvertSQLType(goType)
|
||||
|
||||
if getSQLResult != convertResult {
|
||||
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
|
||||
goType, getSQLResult, convertResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
|
||||
// Test that GetSQLType returns "text" for unknown types
|
||||
// while ConvertSQLType returns the input unchanged
|
||||
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
|
||||
|
||||
for _, unknown := range unknownTypes {
|
||||
getSQLResult := GetSQLType(unknown)
|
||||
convertResult := ConvertSQLType(unknown)
|
||||
|
||||
if getSQLResult != "text" {
|
||||
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
|
||||
}
|
||||
|
||||
if convertResult != unknown {
|
||||
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
136
pkg/pgsql/keywords_test.go
Normal file
136
pkg/pgsql/keywords_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetPostgresKeywords(t *testing.T) {
|
||||
keywords := GetPostgresKeywords()
|
||||
|
||||
// Test that keywords are returned
|
||||
if len(keywords) == 0 {
|
||||
t.Fatal("Expected non-empty list of keywords")
|
||||
}
|
||||
|
||||
// Test that we get all keywords from the map
|
||||
expectedCount := len(postgresKeywords)
|
||||
if len(keywords) != expectedCount {
|
||||
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
|
||||
}
|
||||
|
||||
// Test that all returned keywords exist in the map
|
||||
for _, keyword := range keywords {
|
||||
if !postgresKeywords[keyword] {
|
||||
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that no duplicate keywords are returned
|
||||
seen := make(map[string]bool)
|
||||
for _, keyword := range keywords {
|
||||
if seen[keyword] {
|
||||
t.Errorf("Duplicate keyword found: %q", keyword)
|
||||
}
|
||||
seen[keyword] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyword string
|
||||
want bool
|
||||
}{
|
||||
{"SELECT keyword", "select", true},
|
||||
{"FROM keyword", "from", true},
|
||||
{"WHERE keyword", "where", true},
|
||||
{"TABLE keyword", "table", true},
|
||||
{"PRIMARY keyword", "primary", true},
|
||||
{"FOREIGN keyword", "foreign", true},
|
||||
{"CREATE keyword", "create", true},
|
||||
{"DROP keyword", "drop", true},
|
||||
{"ALTER keyword", "alter", true},
|
||||
{"INDEX keyword", "index", true},
|
||||
{"NOT keyword", "not", true},
|
||||
{"NULL keyword", "null", true},
|
||||
{"TRUE keyword", "true", true},
|
||||
{"FALSE keyword", "false", true},
|
||||
{"Non-keyword lowercase", "notakeyword", false},
|
||||
{"Non-keyword uppercase", "NOTAKEYWORD", false},
|
||||
{"Empty string", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := postgresKeywords[tt.keyword]
|
||||
if got != tt.want {
|
||||
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMapContent(t *testing.T) {
|
||||
// Test that the map contains expected common keywords
|
||||
commonKeywords := []string{
|
||||
"select", "insert", "update", "delete", "create", "drop", "alter",
|
||||
"table", "index", "view", "schema", "function", "procedure",
|
||||
"primary", "foreign", "key", "constraint", "unique", "check",
|
||||
"null", "not", "and", "or", "like", "in", "between",
|
||||
"join", "inner", "left", "right", "cross", "full", "outer",
|
||||
"where", "having", "group", "order", "limit", "offset",
|
||||
"union", "intersect", "except",
|
||||
"begin", "commit", "rollback", "transaction",
|
||||
}
|
||||
|
||||
for _, keyword := range commonKeywords {
|
||||
if !postgresKeywords[keyword] {
|
||||
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMapSize(t *testing.T) {
|
||||
// PostgreSQL has a substantial list of reserved keywords
|
||||
// This test ensures the map has a reasonable number of entries
|
||||
minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words
|
||||
|
||||
if len(postgresKeywords) < minExpectedKeywords {
|
||||
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
|
||||
minExpectedKeywords, len(postgresKeywords))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPostgresKeywordsConsistency(t *testing.T) {
|
||||
// Test that calling GetPostgresKeywords multiple times returns consistent results
|
||||
keywords1 := GetPostgresKeywords()
|
||||
keywords2 := GetPostgresKeywords()
|
||||
|
||||
if len(keywords1) != len(keywords2) {
|
||||
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
|
||||
len(keywords1), len(keywords2))
|
||||
}
|
||||
|
||||
// Create a map from both results to compare
|
||||
map1 := make(map[string]bool)
|
||||
map2 := make(map[string]bool)
|
||||
|
||||
for _, k := range keywords1 {
|
||||
map1[k] = true
|
||||
}
|
||||
for _, k := range keywords2 {
|
||||
map2[k] = true
|
||||
}
|
||||
|
||||
// Check that both contain the same keywords
|
||||
for k := range map1 {
|
||||
if !map2[k] {
|
||||
t.Errorf("Keyword %q present in first call but not in second", k)
|
||||
}
|
||||
}
|
||||
for k := range map2 {
|
||||
if !map1[k] {
|
||||
t.Errorf("Keyword %q present in second call but not in first", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -632,6 +632,9 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
column.Name = parts[0]
|
||||
}
|
||||
|
||||
// Track if we found explicit nullability markers
|
||||
hasExplicitNullableMarker := false
|
||||
|
||||
// Parse tag attributes
|
||||
for _, part := range parts[1:] {
|
||||
kv := strings.SplitN(part, ":", 2)
|
||||
@@ -649,6 +652,10 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
column.IsPrimaryKey = true
|
||||
case "notnull":
|
||||
column.NotNull = true
|
||||
hasExplicitNullableMarker = true
|
||||
case "nullzero":
|
||||
column.NotNull = false
|
||||
hasExplicitNullableMarker = true
|
||||
case "autoincrement":
|
||||
column.AutoIncrement = true
|
||||
case "default":
|
||||
@@ -664,17 +671,15 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
|
||||
// Determine if nullable based on Go type and bun tags
|
||||
// In Bun:
|
||||
// - nullzero tag means the field is nullable (can be NULL in DB)
|
||||
// - absence of nullzero means the field is NOT NULL
|
||||
// - primitive types (int64, bool, string) are NOT NULL by default
|
||||
column.NotNull = true
|
||||
// Primary keys are always NOT NULL
|
||||
|
||||
if strings.Contains(bunTag, "nullzero") {
|
||||
column.NotNull = false
|
||||
} else {
|
||||
// - explicit "notnull" tag means NOT NULL
|
||||
// - explicit "nullzero" tag means nullable
|
||||
// - absence of explicit markers: infer from Go type
|
||||
if !hasExplicitNullableMarker {
|
||||
// Infer from Go type if no explicit marker found
|
||||
column.NotNull = !r.isNullableGoType(fieldType)
|
||||
}
|
||||
|
||||
// Primary keys are always NOT NULL
|
||||
if column.IsPrimaryKey {
|
||||
column.NotNull = true
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
@@ -24,11 +26,23 @@ func NewReader(options *readers.ReaderOptions) *Reader {
|
||||
}
|
||||
|
||||
// ReadDatabase reads and parses DBML input, returning a Database model
|
||||
// If FilePath points to a directory, all .dbml files are loaded and merged
|
||||
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
||||
if r.options.FilePath == "" {
|
||||
return nil, fmt.Errorf("file path is required for DBML reader")
|
||||
}
|
||||
|
||||
// Check if path is a directory
|
||||
info, err := os.Stat(r.options.FilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to stat path: %w", err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return r.readDirectoryDBML(r.options.FilePath)
|
||||
}
|
||||
|
||||
// Single file - existing logic
|
||||
content, err := os.ReadFile(r.options.FilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
@@ -67,15 +81,341 @@ func (r *Reader) ReadTable() (*models.Table, error) {
|
||||
return schema.Tables[0], nil
|
||||
}
|
||||
|
||||
// stripQuotes removes surrounding quotes from an identifier
|
||||
// readDirectoryDBML processes all .dbml files in directory
|
||||
// Returns merged Database model
|
||||
func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
|
||||
// Discover and sort DBML files
|
||||
files, err := r.discoverDBMLFiles(dirPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to discover DBML files: %w", err)
|
||||
}
|
||||
|
||||
// If no files found, return empty database
|
||||
if len(files) == 0 {
|
||||
db := models.InitDatabase("database")
|
||||
if r.options.Metadata != nil {
|
||||
if name, ok := r.options.Metadata["name"].(string); ok {
|
||||
db.Name = name
|
||||
}
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Initialize database (will be merged with files)
|
||||
var db *models.Database
|
||||
|
||||
// Process each file in sorted order
|
||||
for _, filePath := range files {
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
fileDB, err := r.parseDBML(string(content))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// First file initializes the database
|
||||
if db == nil {
|
||||
db = fileDB
|
||||
} else {
|
||||
// Subsequent files are merged
|
||||
mergeDatabase(db, fileDB)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// splitIdentifier splits a dotted identifier while respecting quotes
|
||||
// Handles cases like: "schema.with.dots"."table"."column"
|
||||
func splitIdentifier(s string) []string {
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
inQuote := false
|
||||
quoteChar := byte(0)
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
|
||||
if !inQuote {
|
||||
switch ch {
|
||||
case '"', '\'':
|
||||
inQuote = true
|
||||
quoteChar = ch
|
||||
current.WriteByte(ch)
|
||||
case '.':
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
default:
|
||||
current.WriteByte(ch)
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(ch)
|
||||
if ch == quoteChar {
|
||||
inQuote = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// stripQuotes removes surrounding quotes and comments from an identifier
|
||||
func stripQuotes(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Remove DBML comments in brackets (e.g., [note: 'description'])
|
||||
// This handles inline comments like: "table_name" [note: 'comment']
|
||||
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
|
||||
s = commentRegex.ReplaceAllString(s, "")
|
||||
|
||||
// Trim again after removing comments
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Remove surrounding quotes (double or single)
|
||||
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
|
||||
return s[1 : len(s)-1]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// parseFilePrefix extracts numeric prefix from filename
|
||||
// Examples: "1_schema.dbml" -> (1, true), "tables.dbml" -> (0, false)
|
||||
func parseFilePrefix(filename string) (int, bool) {
|
||||
base := filepath.Base(filename)
|
||||
re := regexp.MustCompile(`^(\d+)[_-]`)
|
||||
matches := re.FindStringSubmatch(base)
|
||||
if len(matches) > 1 {
|
||||
var prefix int
|
||||
_, err := fmt.Sscanf(matches[1], "%d", &prefix)
|
||||
if err == nil {
|
||||
return prefix, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// hasCommentedRefs scans file content for commented-out Ref statements
|
||||
// Returns true if file contains lines like: // Ref: table.col > other.col
|
||||
func hasCommentedRefs(filePath string) (bool, error) {
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(content)))
|
||||
commentedRefRegex := regexp.MustCompile(`^\s*//.*Ref:\s+`)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if commentedRefRegex.MatchString(line) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// discoverDBMLFiles finds all .dbml files in directory and returns them sorted
|
||||
func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) {
|
||||
pattern := filepath.Join(dirPath, "*.dbml")
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to glob .dbml files: %w", err)
|
||||
}
|
||||
|
||||
return sortDBMLFiles(files), nil
|
||||
}
|
||||
|
||||
// sortDBMLFiles sorts files by:
|
||||
// 1. Files without commented refs (by numeric prefix, then alphabetically)
|
||||
// 2. Files with commented refs (by numeric prefix, then alphabetically)
|
||||
func sortDBMLFiles(files []string) []string {
|
||||
// Create a slice to hold file info for sorting
|
||||
type fileInfo struct {
|
||||
path string
|
||||
hasCommented bool
|
||||
prefix int
|
||||
hasPrefix bool
|
||||
basename string
|
||||
}
|
||||
|
||||
fileInfos := make([]fileInfo, 0, len(files))
|
||||
|
||||
for _, file := range files {
|
||||
hasCommented, err := hasCommentedRefs(file)
|
||||
if err != nil {
|
||||
// If we can't read the file, treat it as not having commented refs
|
||||
hasCommented = false
|
||||
}
|
||||
|
||||
prefix, hasPrefix := parseFilePrefix(file)
|
||||
basename := filepath.Base(file)
|
||||
|
||||
fileInfos = append(fileInfos, fileInfo{
|
||||
path: file,
|
||||
hasCommented: hasCommented,
|
||||
prefix: prefix,
|
||||
hasPrefix: hasPrefix,
|
||||
basename: basename,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by: hasCommented (false first), hasPrefix (true first), prefix, basename
|
||||
sort.Slice(fileInfos, func(i, j int) bool {
|
||||
// First, sort by commented refs (files without commented refs come first)
|
||||
if fileInfos[i].hasCommented != fileInfos[j].hasCommented {
|
||||
return !fileInfos[i].hasCommented
|
||||
}
|
||||
|
||||
// Then by presence of prefix (files with prefix come first)
|
||||
if fileInfos[i].hasPrefix != fileInfos[j].hasPrefix {
|
||||
return fileInfos[i].hasPrefix
|
||||
}
|
||||
|
||||
// If both have prefix, sort by prefix value
|
||||
if fileInfos[i].hasPrefix && fileInfos[j].hasPrefix {
|
||||
if fileInfos[i].prefix != fileInfos[j].prefix {
|
||||
return fileInfos[i].prefix < fileInfos[j].prefix
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, sort alphabetically by basename
|
||||
return fileInfos[i].basename < fileInfos[j].basename
|
||||
})
|
||||
|
||||
// Extract sorted paths
|
||||
sortedFiles := make([]string, len(fileInfos))
|
||||
for i, info := range fileInfos {
|
||||
sortedFiles[i] = info.path
|
||||
}
|
||||
|
||||
return sortedFiles
|
||||
}
|
||||
|
||||
// mergeTable combines two table definitions
|
||||
// Merges: Columns (map), Constraints (map), Indexes (map), Relationships (map)
|
||||
// Uses first non-empty Description
|
||||
func mergeTable(baseTable, fileTable *models.Table) {
|
||||
// Merge columns (map naturally merges - later keys overwrite)
|
||||
for key, col := range fileTable.Columns {
|
||||
baseTable.Columns[key] = col
|
||||
}
|
||||
|
||||
// Merge constraints
|
||||
for key, constraint := range fileTable.Constraints {
|
||||
baseTable.Constraints[key] = constraint
|
||||
}
|
||||
|
||||
// Merge indexes
|
||||
for key, index := range fileTable.Indexes {
|
||||
baseTable.Indexes[key] = index
|
||||
}
|
||||
|
||||
// Merge relationships
|
||||
for key, rel := range fileTable.Relationships {
|
||||
baseTable.Relationships[key] = rel
|
||||
}
|
||||
|
||||
// Use first non-empty description
|
||||
if baseTable.Description == "" && fileTable.Description != "" {
|
||||
baseTable.Description = fileTable.Description
|
||||
}
|
||||
|
||||
// Merge metadata maps
|
||||
if baseTable.Metadata == nil {
|
||||
baseTable.Metadata = make(map[string]any)
|
||||
}
|
||||
for key, val := range fileTable.Metadata {
|
||||
baseTable.Metadata[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
// mergeSchema finds or creates schema and merges tables
|
||||
func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) {
|
||||
// Find existing schema by name (normalize names by stripping quotes)
|
||||
var existingSchema *models.Schema
|
||||
fileSchemaName := stripQuotes(fileSchema.Name)
|
||||
for _, schema := range baseDB.Schemas {
|
||||
if stripQuotes(schema.Name) == fileSchemaName {
|
||||
existingSchema = schema
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If schema doesn't exist, add it and return
|
||||
if existingSchema == nil {
|
||||
baseDB.Schemas = append(baseDB.Schemas, fileSchema)
|
||||
return
|
||||
}
|
||||
|
||||
// Merge tables from fileSchema into existingSchema
|
||||
for _, fileTable := range fileSchema.Tables {
|
||||
// Find existing table by name (normalize names by stripping quotes)
|
||||
var existingTable *models.Table
|
||||
fileTableName := stripQuotes(fileTable.Name)
|
||||
for _, table := range existingSchema.Tables {
|
||||
if stripQuotes(table.Name) == fileTableName {
|
||||
existingTable = table
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If table doesn't exist, add it
|
||||
if existingTable == nil {
|
||||
existingSchema.Tables = append(existingSchema.Tables, fileTable)
|
||||
} else {
|
||||
// Merge table properties - tables are identical, skip
|
||||
mergeTable(existingTable, fileTable)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge other schema properties
|
||||
existingSchema.Views = append(existingSchema.Views, fileSchema.Views...)
|
||||
existingSchema.Sequences = append(existingSchema.Sequences, fileSchema.Sequences...)
|
||||
existingSchema.Scripts = append(existingSchema.Scripts, fileSchema.Scripts...)
|
||||
|
||||
// Merge permissions
|
||||
if existingSchema.Permissions == nil {
|
||||
existingSchema.Permissions = make(map[string]string)
|
||||
}
|
||||
for key, val := range fileSchema.Permissions {
|
||||
existingSchema.Permissions[key] = val
|
||||
}
|
||||
|
||||
// Merge metadata
|
||||
if existingSchema.Metadata == nil {
|
||||
existingSchema.Metadata = make(map[string]any)
|
||||
}
|
||||
for key, val := range fileSchema.Metadata {
|
||||
existingSchema.Metadata[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
// mergeDatabase merges schemas from fileDB into baseDB
|
||||
func mergeDatabase(baseDB, fileDB *models.Database) {
|
||||
// Merge each schema from fileDB
|
||||
for _, fileSchema := range fileDB.Schemas {
|
||||
mergeSchema(baseDB, fileSchema)
|
||||
}
|
||||
|
||||
// Merge domains
|
||||
baseDB.Domains = append(baseDB.Domains, fileDB.Domains...)
|
||||
|
||||
// Use first non-empty description
|
||||
if baseDB.Description == "" && fileDB.Description != "" {
|
||||
baseDB.Description = fileDB.Description
|
||||
}
|
||||
}
|
||||
|
||||
// parseDBML parses DBML content and returns a Database model
|
||||
func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
||||
db := models.InitDatabase("database")
|
||||
@@ -109,7 +449,9 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
||||
// Parse Table definition
|
||||
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
|
||||
tableName := matches[1]
|
||||
parts := strings.Split(tableName, ".")
|
||||
// Strip comments/notes before parsing to avoid dots in notes
|
||||
tableName = strings.TrimSpace(regexp.MustCompile(`\s*\[.*?\]\s*`).ReplaceAllString(tableName, ""))
|
||||
parts := splitIdentifier(tableName)
|
||||
|
||||
if len(parts) == 2 {
|
||||
currentSchema = stripQuotes(parts[0])
|
||||
@@ -261,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("uq_%s", columnName),
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
@@ -287,10 +631,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
refOp := strings.TrimSpace(refStr)
|
||||
var isReverse bool
|
||||
if strings.HasPrefix(refOp, "<") {
|
||||
isReverse = column.IsPrimaryKey // < on PK means "is referenced by" (reverse)
|
||||
} else if strings.HasPrefix(refOp, ">") {
|
||||
isReverse = !column.IsPrimaryKey // > on FK means reverse
|
||||
// < means "is referenced by" - only makes sense on PK columns
|
||||
isReverse = column.IsPrimaryKey
|
||||
}
|
||||
// > means "references" - always a forward FK, never reverse
|
||||
|
||||
constraint = r.parseRef(refStr)
|
||||
if constraint != nil {
|
||||
@@ -310,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate short constraint name based on the column
|
||||
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -332,27 +676,31 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
||||
// Format: (columns) [attributes] OR columnname [attributes]
|
||||
var columns []string
|
||||
|
||||
if strings.Contains(line, "(") && strings.Contains(line, ")") {
|
||||
// Find the attributes section to avoid parsing parentheses in notes/attributes
|
||||
attrStart := strings.Index(line, "[")
|
||||
columnPart := line
|
||||
if attrStart > 0 {
|
||||
columnPart = line[:attrStart]
|
||||
}
|
||||
|
||||
if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") {
|
||||
// Multi-column format: (col1, col2) [attributes]
|
||||
colStart := strings.Index(line, "(")
|
||||
colEnd := strings.Index(line, ")")
|
||||
colStart := strings.Index(columnPart, "(")
|
||||
colEnd := strings.Index(columnPart, ")")
|
||||
if colStart >= colEnd {
|
||||
return nil
|
||||
}
|
||||
|
||||
columnsStr := line[colStart+1 : colEnd]
|
||||
columnsStr := columnPart[colStart+1 : colEnd]
|
||||
for _, col := range strings.Split(columnsStr, ",") {
|
||||
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
|
||||
}
|
||||
} else if strings.Contains(line, "[") {
|
||||
} else if attrStart > 0 {
|
||||
// Single column format: columnname [attributes]
|
||||
// Extract column name before the bracket
|
||||
idx := strings.Index(line, "[")
|
||||
if idx > 0 {
|
||||
colName := strings.TrimSpace(line[:idx])
|
||||
if colName != "" {
|
||||
columns = []string{stripQuotes(colName)}
|
||||
}
|
||||
colName := strings.TrimSpace(columnPart)
|
||||
if colName != "" {
|
||||
columns = []string{stripQuotes(colName)}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
||||
|
||||
// Generate name if not provided
|
||||
if index.Name == "" {
|
||||
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
|
||||
prefix := "idx"
|
||||
if index.Unique {
|
||||
prefix = "uidx"
|
||||
}
|
||||
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
|
||||
}
|
||||
|
||||
return index
|
||||
@@ -451,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate short constraint name based on the source column
|
||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||
if len(fromColumns) > 0 {
|
||||
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
|
||||
// Generate constraint name based on table and columns
|
||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
|
||||
if len(fromColumns) == 0 {
|
||||
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(
|
||||
@@ -510,7 +862,7 @@ func (r *Reader) parseTableRef(ref string) (schema, table string, columns []stri
|
||||
}
|
||||
|
||||
// Parse schema, table, and optionally column
|
||||
parts := strings.Split(strings.TrimSpace(ref), ".")
|
||||
parts := splitIdentifier(strings.TrimSpace(ref))
|
||||
if len(parts) == 3 {
|
||||
// Format: "schema"."table"."column"
|
||||
schema = stripQuotes(parts[0])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dbml
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -517,3 +518,356 @@ func TestGetForeignKeys(t *testing.T) {
|
||||
t.Error("Expected foreign key constraint type")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests for multi-file directory loading
|
||||
|
||||
func TestReadDirectory_MultipleFiles(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
// Should have public schema
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Fatal("Expected at least one schema")
|
||||
}
|
||||
|
||||
var publicSchema *models.Schema
|
||||
for _, schema := range db.Schemas {
|
||||
if schema.Name == "public" {
|
||||
publicSchema = schema
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if publicSchema == nil {
|
||||
t.Fatal("Public schema not found")
|
||||
}
|
||||
|
||||
// Should have 3 tables: users, posts, comments
|
||||
if len(publicSchema.Tables) != 3 {
|
||||
t.Fatalf("Expected 3 tables, got %d", len(publicSchema.Tables))
|
||||
}
|
||||
|
||||
// Find tables
|
||||
var usersTable, postsTable, commentsTable *models.Table
|
||||
for _, table := range publicSchema.Tables {
|
||||
switch table.Name {
|
||||
case "users":
|
||||
usersTable = table
|
||||
case "posts":
|
||||
postsTable = table
|
||||
case "comments":
|
||||
commentsTable = table
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
if commentsTable == nil {
|
||||
t.Fatal("Comments table not found")
|
||||
}
|
||||
|
||||
// Verify users table has merged columns from 1_users.dbml and 3_add_columns.dbml
|
||||
expectedUserColumns := []string{"id", "email", "name", "created_at"}
|
||||
if len(usersTable.Columns) != len(expectedUserColumns) {
|
||||
t.Errorf("Expected %d columns in users table, got %d", len(expectedUserColumns), len(usersTable.Columns))
|
||||
}
|
||||
|
||||
for _, colName := range expectedUserColumns {
|
||||
if _, exists := usersTable.Columns[colName]; !exists {
|
||||
t.Errorf("Expected column '%s' in users table", colName)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify posts table columns
|
||||
expectedPostColumns := []string{"id", "user_id", "title", "content", "created_at"}
|
||||
for _, colName := range expectedPostColumns {
|
||||
if _, exists := postsTable.Columns[colName]; !exists {
|
||||
t.Errorf("Expected column '%s' in posts table", colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDirectory_TableMerging(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
// Find users table
|
||||
var usersTable *models.Table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == "users" && schema.Name == "public" {
|
||||
usersTable = table
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
|
||||
// Verify columns from file 1 (id, email)
|
||||
if _, exists := usersTable.Columns["id"]; !exists {
|
||||
t.Error("Column 'id' from 1_users.dbml not found")
|
||||
}
|
||||
if _, exists := usersTable.Columns["email"]; !exists {
|
||||
t.Error("Column 'email' from 1_users.dbml not found")
|
||||
}
|
||||
|
||||
// Verify columns from file 3 (name, created_at)
|
||||
if _, exists := usersTable.Columns["name"]; !exists {
|
||||
t.Error("Column 'name' from 3_add_columns.dbml not found")
|
||||
}
|
||||
if _, exists := usersTable.Columns["created_at"]; !exists {
|
||||
t.Error("Column 'created_at' from 3_add_columns.dbml not found")
|
||||
}
|
||||
|
||||
// Verify column properties from file 1
|
||||
emailCol := usersTable.Columns["email"]
|
||||
if !emailCol.NotNull {
|
||||
t.Error("Email column should be not null (from 1_users.dbml)")
|
||||
}
|
||||
if emailCol.Type != "varchar(255)" {
|
||||
t.Errorf("Expected email type 'varchar(255)', got '%s'", emailCol.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDirectory_CommentedRefsLast(t *testing.T) {
|
||||
// This test verifies that files with commented refs are processed last
|
||||
// by checking that the file discovery returns them in the correct order
|
||||
dirPath := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile")
|
||||
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: dirPath,
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
files, err := reader.discoverDBMLFiles(dirPath)
|
||||
if err != nil {
|
||||
t.Fatalf("discoverDBMLFiles() error = %v", err)
|
||||
}
|
||||
|
||||
if len(files) < 2 {
|
||||
t.Skip("Not enough files to test ordering")
|
||||
}
|
||||
|
||||
// Check that 9_refs.dbml (which has commented refs) comes last
|
||||
lastFile := filepath.Base(files[len(files)-1])
|
||||
if lastFile != "9_refs.dbml" {
|
||||
t.Errorf("Expected last file to be '9_refs.dbml' (has commented refs), got '%s'", lastFile)
|
||||
}
|
||||
|
||||
// Check that numbered files without commented refs come first
|
||||
firstFile := filepath.Base(files[0])
|
||||
if firstFile != "1_users.dbml" {
|
||||
t.Errorf("Expected first file to be '1_users.dbml', got '%s'", firstFile)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDirectory_EmptyDirectory(t *testing.T) {
|
||||
// Create a temporary empty directory
|
||||
tmpDir := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "empty_test_dir")
|
||||
err := os.MkdirAll(tmpDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: tmpDir,
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() should not error on empty directory, got: %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
// Empty directory should return empty database
|
||||
if len(db.Schemas) != 0 {
|
||||
t.Errorf("Expected 0 schemas for empty directory, got %d", len(db.Schemas))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDatabase_BackwardCompat(t *testing.T) {
|
||||
// Test that single file loading still works
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "simple.dbml"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Fatal("Expected at least one schema")
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
if len(schema.Tables) != 1 {
|
||||
t.Fatalf("Expected 1 table, got %d", len(schema.Tables))
|
||||
}
|
||||
|
||||
table := schema.Tables[0]
|
||||
if table.Name != "users" {
|
||||
t.Errorf("Expected table name 'users', got '%s'", table.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilePrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
wantPrefix int
|
||||
wantHas bool
|
||||
}{
|
||||
{"1_schema.dbml", 1, true},
|
||||
{"2_tables.dbml", 2, true},
|
||||
{"10_relationships.dbml", 10, true},
|
||||
{"99_data.dbml", 99, true},
|
||||
{"schema.dbml", 0, false},
|
||||
{"tables_no_prefix.dbml", 0, false},
|
||||
{"/path/to/1_file.dbml", 1, true},
|
||||
{"/path/to/file.dbml", 0, false},
|
||||
{"1-file.dbml", 1, true},
|
||||
{"2-another.dbml", 2, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filename, func(t *testing.T) {
|
||||
gotPrefix, gotHas := parseFilePrefix(tt.filename)
|
||||
if gotPrefix != tt.wantPrefix {
|
||||
t.Errorf("parseFilePrefix(%s) prefix = %d, want %d", tt.filename, gotPrefix, tt.wantPrefix)
|
||||
}
|
||||
if gotHas != tt.wantHas {
|
||||
t.Errorf("parseFilePrefix(%s) hasPrefix = %v, want %v", tt.filename, gotHas, tt.wantHas)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstraintNaming(t *testing.T) {
|
||||
// Test that constraints are named with proper prefixes
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
// Find users table
|
||||
var usersTable *models.Table
|
||||
var postsTable *models.Table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == "users" {
|
||||
usersTable = table
|
||||
} else if table.Name == "posts" {
|
||||
postsTable = table
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
|
||||
// Test unique constraint naming: ukey_table_column
|
||||
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
|
||||
t.Error("Expected unique constraint 'ukey_users_email' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
|
||||
}
|
||||
|
||||
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
|
||||
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||
}
|
||||
|
||||
// Test foreign key naming: fk_table_column
|
||||
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
|
||||
t.Error("Expected foreign key 'fk_posts_user_id' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||
}
|
||||
|
||||
// Test unique index naming: uidx_table_columns
|
||||
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
|
||||
t.Error("Expected unique index 'uidx_posts_slug' not found")
|
||||
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||
}
|
||||
|
||||
// Test regular index naming: idx_table_columns
|
||||
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
|
||||
t.Error("Expected index 'idx_posts_user_id_published' not found")
|
||||
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func getKeys[V any](m map[string]V) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func TestHasCommentedRefs(t *testing.T) {
|
||||
// Test with the actual multifile test fixtures
|
||||
tests := []struct {
|
||||
filename string
|
||||
wantHas bool
|
||||
}{
|
||||
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "1_users.dbml"), false},
|
||||
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "2_posts.dbml"), false},
|
||||
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "3_add_columns.dbml"), false},
|
||||
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "9_refs.dbml"), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(filepath.Base(tt.filename), func(t *testing.T) {
|
||||
gotHas, err := hasCommentedRefs(tt.filename)
|
||||
if err != nil {
|
||||
t.Fatalf("hasCommentedRefs() error = %v", err)
|
||||
}
|
||||
if gotHas != tt.wantHas {
|
||||
t.Errorf("hasCommentedRefs(%s) = %v, want %v", filepath.Base(tt.filename), gotHas, tt.wantHas)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,6 +79,8 @@ func (r *Reader) convertToDatabase(dctx *models.DCTXDictionary) (*models.Databas
|
||||
db := models.InitDatabase(dbName)
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Note: DCTX doesn't have database GUID, but schema can use dictionary name if available
|
||||
|
||||
// Create GUID mappings for tables and keys
|
||||
tableGuidMap := make(map[string]string) // GUID -> table name
|
||||
keyGuidMap := make(map[string]*models.DCTXKey) // GUID -> key definition
|
||||
@@ -162,6 +164,10 @@ func (r *Reader) convertTable(dctxTable *models.DCTXTable) (*models.Table, map[s
|
||||
tableName := r.sanitizeName(dctxTable.Name)
|
||||
table := models.InitTable(tableName, "public")
|
||||
table.Description = dctxTable.Description
|
||||
// Assign GUID from DCTX table
|
||||
if dctxTable.Guid != "" {
|
||||
table.GUID = dctxTable.Guid
|
||||
}
|
||||
|
||||
fieldGuidMap := make(map[string]string)
|
||||
|
||||
@@ -202,6 +208,10 @@ func (r *Reader) convertField(dctxField *models.DCTXField, tableName string) ([]
|
||||
|
||||
// Convert single field
|
||||
column := models.InitColumn(r.sanitizeName(dctxField.Name), tableName, "public")
|
||||
// Assign GUID from DCTX field
|
||||
if dctxField.Guid != "" {
|
||||
column.GUID = dctxField.Guid
|
||||
}
|
||||
|
||||
// Map Clarion data types
|
||||
dataType, length := r.mapDataType(dctxField.DataType, dctxField.Size)
|
||||
@@ -346,6 +356,10 @@ func (r *Reader) convertKey(dctxKey *models.DCTXKey, table *models.Table, fieldG
|
||||
constraint.Table = table.Name
|
||||
constraint.Schema = table.Schema
|
||||
constraint.Columns = columns
|
||||
// Assign GUID from DCTX key
|
||||
if dctxKey.Guid != "" {
|
||||
constraint.GUID = dctxKey.Guid
|
||||
}
|
||||
|
||||
table.Constraints[constraint.Name] = constraint
|
||||
|
||||
@@ -366,6 +380,10 @@ func (r *Reader) convertKey(dctxKey *models.DCTXKey, table *models.Table, fieldG
|
||||
index.Columns = columns
|
||||
index.Unique = dctxKey.Unique
|
||||
index.Type = "btree"
|
||||
// Assign GUID from DCTX key
|
||||
if dctxKey.Guid != "" {
|
||||
index.GUID = dctxKey.Guid
|
||||
}
|
||||
|
||||
table.Indexes[index.Name] = index
|
||||
return nil
|
||||
@@ -460,6 +478,10 @@ func (r *Reader) processRelations(dctx *models.DCTXDictionary, schema *models.Sc
|
||||
constraint.ReferencedColumns = pkColumns
|
||||
constraint.OnDelete = r.mapReferentialAction(relation.Delete)
|
||||
constraint.OnUpdate = r.mapReferentialAction(relation.Update)
|
||||
// Assign GUID from DCTX relation
|
||||
if relation.Guid != "" {
|
||||
constraint.GUID = relation.Guid
|
||||
}
|
||||
|
||||
foreignTable.Constraints[fkName] = constraint
|
||||
|
||||
@@ -473,6 +495,10 @@ func (r *Reader) processRelations(dctx *models.DCTXDictionary, schema *models.Sc
|
||||
relationship.ForeignKey = fkName
|
||||
relationship.Properties["on_delete"] = constraint.OnDelete
|
||||
relationship.Properties["on_update"] = constraint.OnUpdate
|
||||
// Assign GUID from DCTX relation
|
||||
if relation.Guid != "" {
|
||||
relationship.GUID = relation.Guid
|
||||
}
|
||||
|
||||
foreignTable.Relationships[relationshipName] = relationship
|
||||
}
|
||||
|
||||
@@ -140,6 +140,32 @@ func (r *Reader) convertToDatabase(drawSchema *drawdb.DrawDBSchema) (*models.Dat
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
}
|
||||
|
||||
// Convert DrawDB subject areas to domains
|
||||
for _, area := range drawSchema.SubjectAreas {
|
||||
domain := models.InitDomain(area.Name)
|
||||
|
||||
// Find all tables that visually belong to this area
|
||||
// A table belongs to an area if its position is within the area bounds
|
||||
for _, drawTable := range drawSchema.Tables {
|
||||
if drawTable.X >= area.X && drawTable.X <= (area.X+area.Width) &&
|
||||
drawTable.Y >= area.Y && drawTable.Y <= (area.Y+area.Height) {
|
||||
|
||||
schemaName := drawTable.Schema
|
||||
if schemaName == "" {
|
||||
schemaName = "public"
|
||||
}
|
||||
|
||||
domainTable := models.InitDomainTable(drawTable.Name, schemaName)
|
||||
domain.Tables = append(domain.Tables, domainTable)
|
||||
}
|
||||
}
|
||||
|
||||
// Only add domain if it has tables
|
||||
if len(domain.Tables) > 0 {
|
||||
db.Domains = append(db.Domains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -241,11 +241,9 @@ func (r *Reader) parsePgEnum(line string, matches []string) *models.Enum {
|
||||
}
|
||||
}
|
||||
|
||||
return &models.Enum{
|
||||
Name: enumName,
|
||||
Values: values,
|
||||
Schema: "public",
|
||||
}
|
||||
enum := models.InitEnum(enumName, "public")
|
||||
enum.Values = values
|
||||
return enum
|
||||
}
|
||||
|
||||
// parseTableBlock parses a complete pgTable definition block
|
||||
|
||||
@@ -260,11 +260,7 @@ func (r *Reader) parseType(typeName string, lines []string, schema *models.Schem
|
||||
}
|
||||
|
||||
func (r *Reader) parseEnum(enumName string, lines []string, schema *models.Schema) {
|
||||
enum := &models.Enum{
|
||||
Name: enumName,
|
||||
Schema: schema.Name,
|
||||
Values: make([]string, 0),
|
||||
}
|
||||
enum := models.InitEnum(enumName, schema.Name)
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
@@ -329,10 +329,10 @@ func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint)
|
||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||
|
||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||
relationship.FromTable = fk.ReferencedTable
|
||||
relationship.FromSchema = fk.ReferencedSchema
|
||||
relationship.ToTable = table.Name
|
||||
relationship.ToSchema = table.Schema
|
||||
relationship.FromTable = table.Name
|
||||
relationship.FromSchema = table.Schema
|
||||
relationship.ToTable = fk.ReferencedTable
|
||||
relationship.ToSchema = fk.ReferencedSchema
|
||||
relationship.ForeignKey = fk.Name
|
||||
|
||||
// Store constraint actions in properties
|
||||
|
||||
@@ -328,12 +328,12 @@ func TestDeriveRelationship(t *testing.T) {
|
||||
t.Errorf("Expected relationship type %s, got %s", models.OneToMany, rel.Type)
|
||||
}
|
||||
|
||||
if rel.FromTable != "users" {
|
||||
t.Errorf("Expected FromTable 'users', got '%s'", rel.FromTable)
|
||||
if rel.FromTable != "orders" {
|
||||
t.Errorf("Expected FromTable 'orders', got '%s'", rel.FromTable)
|
||||
}
|
||||
|
||||
if rel.ToTable != "orders" {
|
||||
t.Errorf("Expected ToTable 'orders', got '%s'", rel.ToTable)
|
||||
if rel.ToTable != "users" {
|
||||
t.Errorf("Expected ToTable 'users', got '%s'", rel.ToTable)
|
||||
}
|
||||
|
||||
if rel.ForeignKey != "fk_orders_user_id" {
|
||||
|
||||
@@ -128,11 +128,7 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
||||
if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil {
|
||||
currentBlock = "enum"
|
||||
enumName := matches[1]
|
||||
currentEnum = &models.Enum{
|
||||
Name: enumName,
|
||||
Schema: "public",
|
||||
Values: make([]string, 0),
|
||||
}
|
||||
currentEnum = models.InitEnum(enumName, "public")
|
||||
blockContent = []string{}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -93,6 +93,7 @@ fmt.Printf("Found %d scripts\n", len(schema.Scripts))
|
||||
## Features
|
||||
|
||||
- **Recursive Directory Scanning**: Automatically scans all subdirectories
|
||||
- **Symlink Skipping**: Symbolic links are automatically skipped (prevents loops and duplicates)
|
||||
- **Multiple Extensions**: Supports both `.sql` and `.pgsql` files
|
||||
- **Flexible Naming**: Extract metadata from filename patterns
|
||||
- **Error Handling**: Validates directory existence and file accessibility
|
||||
@@ -153,8 +154,9 @@ go test ./pkg/readers/sqldir/
|
||||
```
|
||||
|
||||
Tests include:
|
||||
- Valid file parsing
|
||||
- Valid file parsing (underscore and hyphen formats)
|
||||
- Recursive directory scanning
|
||||
- Symlink skipping
|
||||
- Invalid filename handling
|
||||
- Empty directory handling
|
||||
- Error conditions
|
||||
|
||||
@@ -107,11 +107,20 @@ func (r *Reader) readScripts() ([]*models.Script, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
// Don't process directories as files (WalkDir still descends into them recursively)
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip symlinks
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get filename
|
||||
filename := d.Name()
|
||||
|
||||
@@ -150,13 +159,11 @@ func (r *Reader) readScripts() ([]*models.Script, error) {
|
||||
}
|
||||
|
||||
// Create Script model
|
||||
script := &models.Script{
|
||||
Name: name,
|
||||
Description: fmt.Sprintf("SQL script from %s", relPath),
|
||||
SQL: string(content),
|
||||
Priority: priority,
|
||||
Sequence: uint(sequence),
|
||||
}
|
||||
script := models.InitScript(name)
|
||||
script.Description = fmt.Sprintf("SQL script from %s", relPath)
|
||||
script.SQL = string(content)
|
||||
script.Priority = priority
|
||||
script.Sequence = uint(sequence)
|
||||
|
||||
scripts = append(scripts, script)
|
||||
|
||||
|
||||
@@ -373,3 +373,65 @@ func TestReader_MixedFormat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_SkipSymlinks(t *testing.T) {
|
||||
// Create temporary test directory
|
||||
tempDir, err := os.MkdirTemp("", "sqldir-test-symlink-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a real SQL file
|
||||
realFile := filepath.Join(tempDir, "1_001_real_file.sql")
|
||||
if err := os.WriteFile(realFile, []byte("SELECT 1;"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create real file: %v", err)
|
||||
}
|
||||
|
||||
// Create another file to link to
|
||||
targetFile := filepath.Join(tempDir, "2_001_target.sql")
|
||||
if err := os.WriteFile(targetFile, []byte("SELECT 2;"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create target file: %v", err)
|
||||
}
|
||||
|
||||
// Create a symlink to the target file (this should be skipped)
|
||||
symlinkFile := filepath.Join(tempDir, "3_001_symlink.sql")
|
||||
if err := os.Symlink(targetFile, symlinkFile); err != nil {
|
||||
// Skip test on systems that don't support symlinks (e.g., Windows without admin)
|
||||
t.Skipf("Symlink creation not supported: %v", err)
|
||||
}
|
||||
|
||||
// Create reader
|
||||
reader := NewReader(&readers.ReaderOptions{
|
||||
FilePath: tempDir,
|
||||
})
|
||||
|
||||
// Read database
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
|
||||
// Should only have 2 scripts (real_file and target), symlink should be skipped
|
||||
if len(schema.Scripts) != 2 {
|
||||
t.Errorf("Expected 2 scripts (symlink should be skipped), got %d", len(schema.Scripts))
|
||||
}
|
||||
|
||||
// Verify the scripts are the real files, not the symlink
|
||||
scriptNames := make(map[string]bool)
|
||||
for _, script := range schema.Scripts {
|
||||
scriptNames[script.Name] = true
|
||||
}
|
||||
|
||||
if !scriptNames["real_file"] {
|
||||
t.Error("Expected 'real_file' script to be present")
|
||||
}
|
||||
if !scriptNames["target"] {
|
||||
t.Error("Expected 'target' script to be present")
|
||||
}
|
||||
if scriptNames["symlink"] {
|
||||
t.Error("Symlink script should have been skipped but was found")
|
||||
}
|
||||
}
|
||||
|
||||
490
pkg/reflectutil/helpers_test.go
Normal file
490
pkg/reflectutil/helpers_test.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package reflectutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
Active bool
|
||||
Nested *nestedStruct
|
||||
Private string
|
||||
}
|
||||
|
||||
type nestedStruct struct {
|
||||
Value string
|
||||
Count int
|
||||
}
|
||||
|
||||
func TestDeref(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantValid bool
|
||||
wantKind reflect.Kind
|
||||
}{
|
||||
{
|
||||
name: "non-pointer int",
|
||||
input: 42,
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "single pointer",
|
||||
input: ptrInt(42),
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "double pointer",
|
||||
input: ptrPtr(ptrInt(42)),
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: (*int)(nil),
|
||||
wantValid: false,
|
||||
wantKind: reflect.Ptr,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
wantValid: true,
|
||||
wantKind: reflect.String,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
input: testStruct{Name: "test"},
|
||||
wantValid: true,
|
||||
wantKind: reflect.Struct,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.input)
|
||||
got, valid := Deref(v)
|
||||
|
||||
if valid != tt.wantValid {
|
||||
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
|
||||
}
|
||||
|
||||
if got.Kind() != tt.wantKind {
|
||||
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDerefInterface(t *testing.T) {
|
||||
i := 42
|
||||
pi := &i
|
||||
ppi := &pi
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantKind reflect.Kind
|
||||
}{
|
||||
{"int", 42, reflect.Int},
|
||||
{"pointer to int", &i, reflect.Int},
|
||||
{"double pointer to int", ppi, reflect.Int},
|
||||
{"string", "test", reflect.String},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DerefInterface(tt.input)
|
||||
if got.Kind() != tt.wantKind {
|
||||
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFieldValue(t *testing.T) {
|
||||
ts := testStruct{
|
||||
Name: "John",
|
||||
Age: 30,
|
||||
Active: true,
|
||||
Nested: &nestedStruct{Value: "nested", Count: 5},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
item interface{}
|
||||
field string
|
||||
want interface{}
|
||||
}{
|
||||
{"struct field Name", ts, "Name", "John"},
|
||||
{"struct field Age", ts, "Age", 30},
|
||||
{"struct field Active", ts, "Active", true},
|
||||
{"struct non-existent field", ts, "NonExistent", nil},
|
||||
{"pointer to struct", &ts, "Name", "John"},
|
||||
{"map string key", map[string]string{"key": "value"}, "key", "value"},
|
||||
{"map int key", map[string]int{"count": 42}, "count", 42},
|
||||
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
|
||||
{"nil pointer", (*testStruct)(nil), "Name", nil},
|
||||
{"non-struct non-map", 42, "field", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetFieldValue(tt.item, tt.field)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSliceOrArray(t *testing.T) {
|
||||
arr := [3]int{1, 2, 3}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want bool
|
||||
}{
|
||||
{"slice", []int{1, 2, 3}, true},
|
||||
{"array", arr, true},
|
||||
{"pointer to slice", &[]int{1, 2, 3}, true},
|
||||
{"string", "test", false},
|
||||
{"int", 42, false},
|
||||
{"map", map[string]int{}, false},
|
||||
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
|
||||
{"nil pointer", (*[]int)(nil), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsSliceOrArray(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want bool
|
||||
}{
|
||||
{"map[string]int", map[string]int{"a": 1}, true},
|
||||
{"map[int]string", map[int]string{1: "a"}, true},
|
||||
{"pointer to map", &map[string]int{"a": 1}, true},
|
||||
{"slice", []int{1, 2, 3}, false},
|
||||
{"string", "test", false},
|
||||
{"int", 42, false},
|
||||
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
|
||||
{"nil pointer", (*map[string]int)(nil), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsMap(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsMap() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceLen(t *testing.T) {
|
||||
arr := [3]int{1, 2, 3}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int
|
||||
}{
|
||||
{"slice length 3", []int{1, 2, 3}, 3},
|
||||
{"empty slice", []int{}, 0},
|
||||
{"array length 3", arr, 3},
|
||||
{"pointer to slice", &[]int{1, 2, 3}, 3},
|
||||
{"not a slice", "test", 0},
|
||||
{"int", 42, 0},
|
||||
{"nil slice", ([]int)(nil), 0},
|
||||
{"nil pointer", (*[]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceLen(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int
|
||||
}{
|
||||
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
|
||||
{"empty map", map[string]int{}, 0},
|
||||
{"pointer to map", &map[string]int{"a": 1}, 1},
|
||||
{"not a map", []int{1, 2, 3}, 0},
|
||||
{"string", "test", 0},
|
||||
{"nil map", (map[string]int)(nil), 0},
|
||||
{"nil pointer", (*map[string]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapLen(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapLen() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceToInterfaces(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want []interface{}
|
||||
}{
|
||||
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
|
||||
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
|
||||
{"empty slice", []int{}, []interface{}{}},
|
||||
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
|
||||
{"not a slice", "test", []interface{}{}},
|
||||
{"nil slice", ([]int)(nil), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceToInterfaces(tt.input)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want []interface{}
|
||||
}{
|
||||
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
|
||||
{"empty map", map[string]int{}, []interface{}{}},
|
||||
{"not a map", []int{1, 2, 3}, []interface{}{}},
|
||||
{"nil map", (map[string]int)(nil), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapKeys(tt.input)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
|
||||
}
|
||||
// For maps, order is not guaranteed, so just check length
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int // length of values
|
||||
}{
|
||||
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
|
||||
{"empty map", map[string]int{}, 0},
|
||||
{"not a map", []int{1, 2, 3}, 0},
|
||||
{"nil map", (map[string]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapValues(tt.input)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapGet(t *testing.T) {
|
||||
m := map[string]int{"a": 1, "b": 2}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
key interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"existing key", m, "a", 1},
|
||||
{"existing key b", m, "b", 2},
|
||||
{"non-existing key", m, "c", nil},
|
||||
{"pointer to map", &m, "a", 1},
|
||||
{"not a map", []int{1, 2}, 0, nil},
|
||||
{"nil map", (map[string]int)(nil), "a", nil},
|
||||
{"nil pointer", (*map[string]int)(nil), "a", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapGet(tt.input, tt.key)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("MapGet() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceIndex(t *testing.T) {
|
||||
s := []int{10, 20, 30}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slice interface{}
|
||||
index int
|
||||
want interface{}
|
||||
}{
|
||||
{"index 0", s, 0, 10},
|
||||
{"index 1", s, 1, 20},
|
||||
{"index 2", s, 2, 30},
|
||||
{"negative index", s, -1, nil},
|
||||
{"out of bounds", s, 5, nil},
|
||||
{"pointer to slice", &s, 1, 20},
|
||||
{"not a slice", "test", 0, nil},
|
||||
{"nil slice", ([]int)(nil), 0, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceIndex(tt.slice, tt.index)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a interface{}
|
||||
b interface{}
|
||||
want int
|
||||
}{
|
||||
{"both nil", nil, nil, 0},
|
||||
{"a nil", nil, 5, -1},
|
||||
{"b nil", 5, nil, 1},
|
||||
{"equal strings", "abc", "abc", 0},
|
||||
{"a less than b strings", "abc", "xyz", -1},
|
||||
{"a greater than b strings", "xyz", "abc", 1},
|
||||
{"equal ints", 5, 5, 0},
|
||||
{"a less than b ints", 3, 7, -1},
|
||||
{"a greater than b ints", 10, 5, 1},
|
||||
{"equal floats", 3.14, 3.14, 0},
|
||||
{"a less than b floats", 2.5, 5.5, -1},
|
||||
{"a greater than b floats", 10.5, 5.5, 1},
|
||||
{"equal uints", uint(5), uint(5), 0},
|
||||
{"different types", "abc", 123, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CompareValues(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNestedValue(t *testing.T) {
|
||||
nested := map[string]interface{}{
|
||||
"level1": map[string]interface{}{
|
||||
"level2": map[string]interface{}{
|
||||
"value": "deep",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ts := testStruct{
|
||||
Name: "John",
|
||||
Nested: &nestedStruct{
|
||||
Value: "nested value",
|
||||
Count: 42,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
path string
|
||||
want interface{}
|
||||
}{
|
||||
{"empty path", nested, "", nested},
|
||||
{"single level map", nested, "level1", nested["level1"]},
|
||||
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
|
||||
{"deep nested map", nested, "level1.level2.value", "deep"},
|
||||
{"struct field", ts, "Name", "John"},
|
||||
{"nested struct field", ts, "Nested", ts.Nested},
|
||||
{"non-existent path", nested, "missing.path", nil},
|
||||
{"nil input", nil, "path", nil},
|
||||
{"partial missing path", nested, "level1.missing", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetNestedValue(tt.input, tt.path)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a interface{}
|
||||
b interface{}
|
||||
want bool
|
||||
}{
|
||||
{"equal ints", 42, 42, true},
|
||||
{"different ints", 42, 43, false},
|
||||
{"equal strings", "test", "test", true},
|
||||
{"different strings", "test", "other", false},
|
||||
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
|
||||
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
|
||||
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
|
||||
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
|
||||
{"both nil", nil, nil, true},
|
||||
{"one nil", nil, 42, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DeepEqual(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
// ptrInt returns a pointer to a freshly allocated copy of i.
func ptrInt(i int) *int {
	v := i
	return &v
}
|
||||
|
||||
// ptrPtr returns a pointer to a freshly allocated copy of p.
func ptrPtr(p *int) **int {
	q := p
	return &q
}
|
||||
95
pkg/ui/column_dataops.go
Normal file
95
pkg/ui/column_dataops.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package ui
|
||||
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
|
||||
// Column data operations - business logic for column management
|
||||
|
||||
// CreateColumn creates a new column and adds it to a table
|
||||
func (se *SchemaEditor) CreateColumn(schemaIndex, tableIndex int, name, dataType string, isPrimaryKey, isNotNull bool) *models.Column {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if table.Columns == nil {
|
||||
table.Columns = make(map[string]*models.Column)
|
||||
}
|
||||
|
||||
newColumn := &models.Column{
|
||||
Name: name,
|
||||
Type: dataType,
|
||||
IsPrimaryKey: isPrimaryKey,
|
||||
NotNull: isNotNull,
|
||||
}
|
||||
table.UpdateDate()
|
||||
table.Columns[name] = newColumn
|
||||
return newColumn
|
||||
}
|
||||
|
||||
// UpdateColumn updates an existing column's properties
|
||||
func (se *SchemaEditor) UpdateColumn(schemaIndex, tableIndex int, oldName, newName, dataType string, isPrimaryKey, isNotNull bool, defaultValue interface{}, description string) bool {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
column, exists := table.Columns[oldName]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
table.UpdateDate()
|
||||
|
||||
// If name changed, remove old entry and create new one
|
||||
if oldName != newName {
|
||||
delete(table.Columns, oldName)
|
||||
column.Name = newName
|
||||
table.Columns[newName] = column
|
||||
}
|
||||
|
||||
// Update properties
|
||||
column.Type = dataType
|
||||
column.IsPrimaryKey = isPrimaryKey
|
||||
column.NotNull = isNotNull
|
||||
column.Default = defaultValue
|
||||
column.Description = description
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteColumn removes a column from a table
|
||||
func (se *SchemaEditor) DeleteColumn(schemaIndex, tableIndex int, columnName string) bool {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, exists := table.Columns[columnName]; !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
table.UpdateDate()
|
||||
|
||||
delete(table.Columns, columnName)
|
||||
return true
|
||||
}
|
||||
|
||||
// GetColumn returns a column by name
|
||||
func (se *SchemaEditor) GetColumn(schemaIndex, tableIndex int, columnName string) *models.Column {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return table.Columns[columnName]
|
||||
}
|
||||
|
||||
// GetAllColumns returns all columns in a table
|
||||
func (se *SchemaEditor) GetAllColumns(schemaIndex, tableIndex int) map[string]*models.Column {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return table.Columns
|
||||
}
|
||||
214
pkg/ui/column_screens.go
Normal file
214
pkg/ui/column_screens.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// showColumnEditor shows editor for a specific column
|
||||
func (se *SchemaEditor) showColumnEditor(schemaIndex, tableIndex, colIndex int, column *models.Column) {
|
||||
form := tview.NewForm()
|
||||
|
||||
// Store original name to handle renames
|
||||
originalName := column.Name
|
||||
|
||||
// Local variables to collect changes
|
||||
newName := column.Name
|
||||
newType := column.Type
|
||||
newIsPK := column.IsPrimaryKey
|
||||
newIsNotNull := column.NotNull
|
||||
newDefault := column.Default
|
||||
newDescription := column.Description
|
||||
newGUID := column.GUID
|
||||
|
||||
// Column type options: PostgreSQL, MySQL, SQL Server, and common SQL types
|
||||
columnTypes := []string{
|
||||
// Numeric Types
|
||||
"SMALLINT", "INTEGER", "BIGINT", "INT", "TINYINT", "FLOAT", "REAL", "DOUBLE PRECISION",
|
||||
"DECIMAL(10,2)", "NUMERIC", "DECIMAL", "NUMERIC(10,2)",
|
||||
// Character Types
|
||||
"CHAR", "VARCHAR", "VARCHAR(255)", "TEXT", "NCHAR", "NVARCHAR", "NVARCHAR(255)",
|
||||
// Boolean
|
||||
"BOOLEAN", "BOOL", "BIT",
|
||||
// Date/Time Types
|
||||
"DATE", "TIME", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "INTERVAL",
|
||||
"DATETIME", "DATETIME2", "DATEFIRST",
|
||||
// UUID and JSON
|
||||
"UUID", "GUID", "JSON", "JSONB",
|
||||
// Binary Types
|
||||
"BYTEA", "BLOB", "IMAGE", "VARBINARY", "VARBINARY(MAX)", "BINARY",
|
||||
// PostgreSQL Special Types
|
||||
"int4range", "int8range", "numrange", "tsrange", "tstzrange", "daterange",
|
||||
"HSTORE", "CITEXT", "INET", "MACADDR", "POINT", "LINE", "LSEG", "BOX", "PATH", "POLYGON", "CIRCLE",
|
||||
// Array Types
|
||||
"INTEGER ARRAY", "VARCHAR ARRAY", "TEXT ARRAY", "BIGINT ARRAY",
|
||||
// MySQL Specific
|
||||
"MEDIUMINT", "DOUBLE", "FLOAT(10,2)",
|
||||
// SQL Server Specific
|
||||
"MONEY", "SMALLMONEY", "SQL_VARIANT",
|
||||
}
|
||||
selectedTypeIndex := 0
|
||||
|
||||
// Add existing type if not already in the list
|
||||
typeExists := false
|
||||
for i, opt := range columnTypes {
|
||||
if opt == column.Type {
|
||||
selectedTypeIndex = i
|
||||
typeExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !typeExists && column.Type != "" {
|
||||
columnTypes = append(columnTypes, column.Type)
|
||||
selectedTypeIndex = len(columnTypes) - 1
|
||||
}
|
||||
|
||||
form.AddInputField("Column Name", column.Name, 40, nil, func(value string) {
|
||||
newName = value
|
||||
})
|
||||
|
||||
form.AddDropDown("Type", columnTypes, selectedTypeIndex, func(option string, index int) {
|
||||
newType = option
|
||||
})
|
||||
|
||||
form.AddCheckbox("Primary Key", column.IsPrimaryKey, func(checked bool) {
|
||||
newIsPK = checked
|
||||
})
|
||||
|
||||
form.AddCheckbox("Not Null", column.NotNull, func(checked bool) {
|
||||
newIsNotNull = checked
|
||||
})
|
||||
|
||||
defaultStr := ""
|
||||
if column.Default != nil {
|
||||
defaultStr = fmt.Sprintf("%v", column.Default)
|
||||
}
|
||||
form.AddInputField("Default Value", defaultStr, 40, nil, func(value string) {
|
||||
newDefault = value
|
||||
})
|
||||
|
||||
form.AddTextArea("Description", column.Description, 40, 5, 0, func(value string) {
|
||||
newDescription = value
|
||||
})
|
||||
|
||||
form.AddInputField("GUID", column.GUID, 40, nil, func(value string) {
|
||||
newGUID = value
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
// Apply changes using dataops
|
||||
se.UpdateColumn(schemaIndex, tableIndex, originalName, newName, newType, newIsPK, newIsNotNull, newDefault, newDescription)
|
||||
se.db.Schemas[schemaIndex].Tables[tableIndex].Columns[newName].GUID = newGUID
|
||||
|
||||
se.pages.RemovePage("column-editor")
|
||||
se.pages.SwitchToPage("table-editor")
|
||||
})
|
||||
|
||||
form.AddButton("Delete", func() {
|
||||
se.showDeleteColumnConfirm(schemaIndex, tableIndex, originalName)
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
// Discard changes - don't apply them
|
||||
se.pages.RemovePage("column-editor")
|
||||
se.pages.SwitchToPage("table-editor")
|
||||
})
|
||||
|
||||
form.SetBorder(true).SetTitle(" Edit Column ").SetTitleAlign(tview.AlignLeft)
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.showExitConfirmation("column-editor", "table-editor")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("column-editor", form, true, true)
|
||||
}
|
||||
|
||||
// showNewColumnDialog shows a modal form for creating a new column on the
// table at se.db.Schemas[schemaIndex].Tables[tableIndex].
//
// Saving validates only that a name was entered, creates the column via
// se.CreateColumn, and rebuilds the table editor so the new column appears.
// Escape routes through the shared exit-confirmation dialog.
func (se *SchemaEditor) showNewColumnDialog(schemaIndex, tableIndex int) {
	form := tview.NewForm()

	// Captured by the field callbacks below; dataType has a sensible default
	// so saving without touching the dropdown still produces a valid column.
	columnName := ""
	dataType := "VARCHAR(255)"

	// Column type options: PostgreSQL, MySQL, SQL Server, and common SQL types.
	// NOTE(review): "DATEFIRST" is a SQL Server SET option, not a data type —
	// it likely does not belong in this list; confirm before removing.
	columnTypes := []string{
		// Numeric Types
		"SMALLINT", "INTEGER", "BIGINT", "INT", "TINYINT", "FLOAT", "REAL", "DOUBLE PRECISION",
		"DECIMAL(10,2)", "NUMERIC", "DECIMAL", "NUMERIC(10,2)",
		// Character Types
		"CHAR", "VARCHAR", "VARCHAR(255)", "TEXT", "NCHAR", "NVARCHAR", "NVARCHAR(255)",
		// Boolean
		"BOOLEAN", "BOOL", "BIT",
		// Date/Time Types
		"DATE", "TIME", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "INTERVAL",
		"DATETIME", "DATETIME2", "DATEFIRST",
		// UUID and JSON
		"UUID", "GUID", "JSON", "JSONB",
		// Binary Types
		"BYTEA", "BLOB", "IMAGE", "VARBINARY", "VARBINARY(MAX)", "BINARY",
		// PostgreSQL Special Types
		"int4range", "int8range", "numrange", "tsrange", "tstzrange", "daterange",
		"HSTORE", "CITEXT", "INET", "MACADDR", "POINT", "LINE", "LSEG", "BOX", "PATH", "POLYGON", "CIRCLE",
		// Array Types
		"INTEGER ARRAY", "VARCHAR ARRAY", "TEXT ARRAY", "BIGINT ARRAY",
		// MySQL Specific
		"MEDIUMINT", "DOUBLE", "FLOAT(10,2)",
		// SQL Server Specific
		"MONEY", "SMALLMONEY", "SQL_VARIANT",
	}
	selectedTypeIndex := 0

	form.AddInputField("Column Name", "", 40, nil, func(value string) {
		columnName = value
	})

	form.AddDropDown("Data Type", columnTypes, selectedTypeIndex, func(option string, index int) {
		dataType = option
	})

	form.AddCheckbox("Primary Key", false, nil)
	form.AddCheckbox("Not Null", false, nil)
	// NOTE(review): the "Unique" checkbox is displayed but its value is never
	// read in the Save handler below, and CreateColumn takes no unique flag —
	// the setting is silently dropped. Confirm whether this is intentional.
	form.AddCheckbox("Unique", false, nil)

	form.AddButton("Save", func() {
		// A column must at least have a name; silently stay on the form otherwise.
		if columnName == "" {
			return
		}

		// Get form values (checkboxes are looked up by label rather than kept
		// in closure variables).
		isPK := form.GetFormItemByLabel("Primary Key").(*tview.Checkbox).IsChecked()
		isNotNull := form.GetFormItemByLabel("Not Null").(*tview.Checkbox).IsChecked()

		se.CreateColumn(schemaIndex, tableIndex, columnName, dataType, isPK, isNotNull)

		// Rebuild the table editor so the freshly created column is visible.
		table := se.db.Schemas[schemaIndex].Tables[tableIndex]
		se.pages.RemovePage("new-column")
		se.pages.RemovePage("table-editor")
		se.showTableEditor(schemaIndex, tableIndex, table)
	})

	form.AddButton("Back", func() {
		// Discard the form and return to a rebuilt table editor.
		table := se.db.Schemas[schemaIndex].Tables[tableIndex]
		se.pages.RemovePage("new-column")
		se.pages.RemovePage("table-editor")
		se.showTableEditor(schemaIndex, tableIndex, table)
	})

	form.SetBorder(true).SetTitle(" New Column ").SetTitleAlign(tview.AlignLeft)
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		// Escape asks for confirmation instead of discarding input outright.
		if event.Key() == tcell.KeyEscape {
			se.showExitConfirmation("new-column", "table-editor")
			return nil
		}
		return event
	})

	se.pages.AddPage("new-column", form, true, true)
}
|
||||
15
pkg/ui/database_dataops.go
Normal file
15
pkg/ui/database_dataops.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// updateDatabase updates database properties
|
||||
func (se *SchemaEditor) updateDatabase(name, description, comment, dbType, dbVersion string) {
|
||||
se.db.Name = name
|
||||
se.db.Description = description
|
||||
se.db.Comment = comment
|
||||
se.db.DatabaseType = models.DatabaseType(dbType)
|
||||
se.db.DatabaseVersion = dbVersion
|
||||
se.db.UpdateDate()
|
||||
}
|
||||
78
pkg/ui/database_screens.go
Normal file
78
pkg/ui/database_screens.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
// showEditDatabaseForm displays a dialog to edit database-level properties
// (name, description, comment, engine type/version, GUID). Saving applies the
// values to the in-memory model and rebuilds the main menu so the header
// reflects the changes; nothing is persisted to disk here.
func (se *SchemaEditor) showEditDatabaseForm() {
	form := tview.NewForm()

	// Snapshot current values; the field callbacks mutate these locals and
	// they are only written back to the model on Save.
	dbName := se.db.Name
	dbDescription := se.db.Description
	dbComment := se.db.Comment
	dbType := string(se.db.DatabaseType)
	dbVersion := se.db.DatabaseVersion
	dbGUID := se.db.GUID

	// Database type options; pre-select the entry matching the current type.
	// NOTE(review): the column-type list elsewhere includes MySQL types, but
	// "mysql" is absent here — confirm whether that is deliberate.
	dbTypeOptions := []string{"pgsql", "mssql", "sqlite"}
	selectedTypeIndex := 0
	for i, opt := range dbTypeOptions {
		if opt == dbType {
			selectedTypeIndex = i
			break
		}
	}

	form.AddInputField("Database Name", dbName, 40, nil, func(value string) {
		dbName = value
	})

	form.AddInputField("Description", dbDescription, 50, nil, func(value string) {
		dbDescription = value
	})

	form.AddInputField("Comment", dbComment, 50, nil, func(value string) {
		dbComment = value
	})

	form.AddDropDown("Database Type", dbTypeOptions, selectedTypeIndex, func(option string, index int) {
		dbType = option
	})

	form.AddInputField("Database Version", dbVersion, 20, nil, func(value string) {
		dbVersion = value
	})

	form.AddInputField("GUID", dbGUID, 40, nil, func(value string) {
		dbGUID = value
	})

	form.AddButton("Save", func() {
		// Name is the only required field.
		if dbName == "" {
			return
		}
		se.updateDatabase(dbName, dbDescription, dbComment, dbType, dbVersion)
		// GUID is applied directly; updateDatabase does not handle it.
		se.db.GUID = dbGUID
		// Rebuild the main menu so it shows the updated database info.
		se.pages.RemovePage("edit-database")
		se.pages.RemovePage("main")
		se.pages.AddPage("main", se.createMainMenu(), true, true)
	})

	form.AddButton("Back", func() {
		// Discard edits; the underlying "main" page is still present.
		se.pages.RemovePage("edit-database")
	})

	form.SetBorder(true).SetTitle(" Edit Database ").SetTitleAlign(tview.AlignLeft)
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		// Escape asks for confirmation before abandoning unsaved edits.
		if event.Key() == tcell.KeyEscape {
			se.showExitConfirmation("edit-database", "main")
			return nil
		}
		return event
	})

	se.pages.AddPage("edit-database", form, true, true)
}
|
||||
139
pkg/ui/dialogs.go
Normal file
139
pkg/ui/dialogs.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
// showExitConfirmation shows a confirmation dialog when trying to exit without saving
|
||||
func (se *SchemaEditor) showExitConfirmation(pageToRemove, pageToSwitchTo string) {
|
||||
modal := tview.NewModal().
|
||||
SetText("Exit without saving changes?").
|
||||
AddButtons([]string{"Cancel", "No, exit without saving"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "No, exit without saving" {
|
||||
se.pages.RemovePage(pageToRemove)
|
||||
se.pages.SwitchToPage(pageToSwitchTo)
|
||||
}
|
||||
se.pages.RemovePage("exit-confirm")
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("exit-confirm")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("exit-confirm", modal, true, true)
|
||||
}
|
||||
|
||||
// showExitEditorConfirm shows confirmation dialog when trying to exit the entire editor
|
||||
func (se *SchemaEditor) showExitEditorConfirm() {
|
||||
modal := tview.NewModal().
|
||||
SetText("Exit RelSpec Editor? Press ESC again to confirm.").
|
||||
AddButtons([]string{"Cancel", "Exit"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "Exit" {
|
||||
se.app.Stop()
|
||||
}
|
||||
se.pages.RemovePage("exit-editor-confirm")
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.app.Stop()
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("exit-editor-confirm", modal, true, true)
|
||||
}
|
||||
|
||||
// showDeleteSchemaConfirm shows confirmation dialog for schema deletion
|
||||
func (se *SchemaEditor) showDeleteSchemaConfirm(schemaIndex int) {
|
||||
modal := tview.NewModal().
|
||||
SetText(fmt.Sprintf("Delete schema '%s'? This will delete all tables in this schema.",
|
||||
se.db.Schemas[schemaIndex].Name)).
|
||||
AddButtons([]string{"Cancel", "Delete"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "Delete" {
|
||||
se.DeleteSchema(schemaIndex)
|
||||
se.pages.RemovePage("schema-editor")
|
||||
se.pages.RemovePage("schemas")
|
||||
se.showSchemaList()
|
||||
}
|
||||
se.pages.RemovePage("confirm-delete-schema")
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("confirm-delete-schema")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("confirm-delete-schema", modal, true, true)
|
||||
}
|
||||
|
||||
// showDeleteTableConfirm shows confirmation dialog for table deletion
|
||||
func (se *SchemaEditor) showDeleteTableConfirm(schemaIndex, tableIndex int) {
|
||||
table := se.db.Schemas[schemaIndex].Tables[tableIndex]
|
||||
modal := tview.NewModal().
|
||||
SetText(fmt.Sprintf("Delete table '%s'? This action cannot be undone.",
|
||||
table.Name)).
|
||||
AddButtons([]string{"Cancel", "Delete"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "Delete" {
|
||||
se.DeleteTable(schemaIndex, tableIndex)
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
se.pages.RemovePage("table-editor")
|
||||
se.pages.RemovePage("schema-editor")
|
||||
se.showSchemaEditor(schemaIndex, schema)
|
||||
}
|
||||
se.pages.RemovePage("confirm-delete-table")
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("confirm-delete-table")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("confirm-delete-table", modal, true, true)
|
||||
}
|
||||
|
||||
// showDeleteColumnConfirm shows confirmation dialog for column deletion
|
||||
func (se *SchemaEditor) showDeleteColumnConfirm(schemaIndex, tableIndex int, columnName string) {
|
||||
modal := tview.NewModal().
|
||||
SetText(fmt.Sprintf("Delete column '%s'? This action cannot be undone.",
|
||||
columnName)).
|
||||
AddButtons([]string{"Cancel", "Delete"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "Delete" {
|
||||
se.DeleteColumn(schemaIndex, tableIndex, columnName)
|
||||
se.pages.RemovePage("column-editor")
|
||||
se.pages.RemovePage("confirm-delete-column")
|
||||
se.pages.SwitchToPage("table-editor")
|
||||
} else {
|
||||
se.pages.RemovePage("confirm-delete-column")
|
||||
}
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("confirm-delete-column")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("confirm-delete-column", modal, true, true)
|
||||
}
|
||||
35
pkg/ui/domain_dataops.go
Normal file
35
pkg/ui/domain_dataops.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// createDomain creates a new domain
|
||||
func (se *SchemaEditor) createDomain(name, description string) {
|
||||
domain := &models.Domain{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Tables: make([]*models.DomainTable, 0),
|
||||
Sequence: uint(len(se.db.Domains)),
|
||||
}
|
||||
|
||||
se.db.Domains = append(se.db.Domains, domain)
|
||||
se.showDomainList()
|
||||
}
|
||||
|
||||
// updateDomain updates an existing domain
|
||||
func (se *SchemaEditor) updateDomain(index int, name, description string) {
|
||||
if index >= 0 && index < len(se.db.Domains) {
|
||||
se.db.Domains[index].Name = name
|
||||
se.db.Domains[index].Description = description
|
||||
se.showDomainList()
|
||||
}
|
||||
}
|
||||
|
||||
// deleteDomain deletes a domain by index
|
||||
func (se *SchemaEditor) deleteDomain(index int) {
|
||||
if index >= 0 && index < len(se.db.Domains) {
|
||||
se.db.Domains = append(se.db.Domains[:index], se.db.Domains[index+1:]...)
|
||||
se.showDomainList()
|
||||
}
|
||||
}
|
||||
258
pkg/ui/domain_screens.go
Normal file
258
pkg/ui/domain_screens.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// showDomainList displays the domain management screen: a table of all
// domains (name, sequence, table count, description) plus "New Domain" and
// "Back" buttons. Enter on a row opens that domain's editor; 'n' creates a
// new domain; 'b' or Escape returns to the main menu. Tab/Shift+Tab cycle
// focus between the table and the two buttons.
func (se *SchemaEditor) showDomainList() {
	flex := tview.NewFlex().SetDirection(tview.FlexRow)

	// Title
	title := tview.NewTextView().
		SetText("[::b]Manage Domains").
		SetDynamicColors(true).
		SetTextAlign(tview.AlignCenter)

	// Create domains table: row 0 is a fixed, non-selectable header.
	domainTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)

	// Add header row. headerWidths intentionally has one fewer entry than
	// headers: the last column (Description) gets no fixed width and takes
	// the remaining space. Widths must stay >= the header text length, or
	// strings.Repeat below would panic on a negative count.
	headers := []string{"Name", "Sequence", "Total Tables", "Description"}
	headerWidths := []int{20, 15, 20}
	for i, header := range headers {
		padding := ""
		if i < len(headerWidths) {
			padding = strings.Repeat(" ", headerWidths[i]-len(header))
		}
		cell := tview.NewTableCell(header + padding).
			SetTextColor(tcell.ColorYellow).
			SetSelectable(false).
			SetAlign(tview.AlignLeft)
		domainTable.SetCell(0, i, cell)
	}

	// Add existing domains, one row each, padded to match the header widths.
	for row, domain := range se.db.Domains {
		// NOTE(review): this re-declaration guards against the classic
		// loop-variable-capture bug, but no closure below captures `domain`,
		// so it is currently redundant.
		domain := domain // capture for closure

		// Name - pad to 20 chars
		nameStr := fmt.Sprintf("%-20s", domain.Name)
		nameCell := tview.NewTableCell(nameStr).SetSelectable(true)
		domainTable.SetCell(row+1, 0, nameCell)

		// Sequence - pad to 15 chars
		seqStr := fmt.Sprintf("%-15s", fmt.Sprintf("%d", domain.Sequence))
		seqCell := tview.NewTableCell(seqStr).SetSelectable(true)
		domainTable.SetCell(row+1, 1, seqCell)

		// Total Tables - pad to 20 chars
		tablesStr := fmt.Sprintf("%-20s", fmt.Sprintf("%d", len(domain.Tables)))
		tablesCell := tview.NewTableCell(tablesStr).SetSelectable(true)
		domainTable.SetCell(row+1, 2, tablesCell)

		// Description - no padding, takes remaining space
		descCell := tview.NewTableCell(domain.Description).SetSelectable(true)
		domainTable.SetCell(row+1, 3, descCell)
	}

	domainTable.SetTitle(" Domains ").SetBorder(true).SetTitleAlign(tview.AlignLeft)

	// Action buttons flex
	btnFlex := tview.NewFlex()
	btnNewDomain := tview.NewButton("New Domain [n]").SetSelectedFunc(func() {
		se.showNewDomainDialog()
	})
	btnBack := tview.NewButton("Back [b]").SetSelectedFunc(func() {
		se.pages.SwitchToPage("main")
		se.pages.RemovePage("domains")
	})

	// Set up button input captures for Tab/Shift+Tab navigation
	// (focus order: table -> New Domain -> Back -> table).
	btnNewDomain.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyBacktab {
			se.app.SetFocus(domainTable)
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(btnBack)
			return nil
		}
		return event
	})

	btnBack.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyBacktab {
			se.app.SetFocus(btnNewDomain)
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(domainTable)
			return nil
		}
		return event
	})

	btnFlex.AddItem(btnNewDomain, 0, 1, true).
		AddItem(btnBack, 0, 1, false)

	// Table-level keys: Escape/b leave, Tab moves focus to the buttons,
	// Enter opens the selected domain, n creates a new one.
	domainTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.pages.SwitchToPage("main")
			se.pages.RemovePage("domains")
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(btnNewDomain)
			return nil
		}
		if event.Key() == tcell.KeyEnter {
			row, _ := domainTable.GetSelection()
			if row > 0 && row <= len(se.db.Domains) { // Skip header row
				domainIndex := row - 1
				se.showDomainEditor(domainIndex, se.db.Domains[domainIndex])
				return nil
			}
		}
		if event.Rune() == 'n' {
			se.showNewDomainDialog()
			return nil
		}
		if event.Rune() == 'b' {
			se.pages.SwitchToPage("main")
			se.pages.RemovePage("domains")
			return nil
		}
		return event
	})

	flex.AddItem(title, 1, 0, false).
		AddItem(domainTable, 0, 1, true).
		AddItem(btnFlex, 1, 0, false)

	se.pages.AddPage("domains", flex, true, true)
}
|
||||
|
||||
// showNewDomainDialog displays a dialog to create a new domain (name +
// description). Saving requires a non-empty name and then re-renders the
// domain list; Back discards the input. Escape goes through the shared
// exit-confirmation dialog.
func (se *SchemaEditor) showNewDomainDialog() {
	form := tview.NewForm()

	domainName := ""
	domainDesc := ""

	form.AddInputField("Name", "", 40, nil, func(value string) {
		domainName = value
	})

	form.AddInputField("Description", "", 50, nil, func(value string) {
		domainDesc = value
	})

	form.AddButton("Save", func() {
		// Name is required; stay on the form otherwise.
		if domainName == "" {
			return
		}
		// NOTE(review): createDomain itself already calls showDomainList, so
		// the list is rendered twice here (once inside createDomain, once
		// below after the stale page is removed). Harmless but redundant.
		se.createDomain(domainName, domainDesc)
		se.pages.RemovePage("new-domain")
		se.pages.RemovePage("domains")
		se.showDomainList()
	})

	form.AddButton("Back", func() {
		// Discard input and rebuild the domain list.
		se.pages.RemovePage("new-domain")
		se.pages.RemovePage("domains")
		se.showDomainList()
	})

	form.SetBorder(true).SetTitle(" New Domain ").SetTitleAlign(tview.AlignLeft)
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		// Escape asks for confirmation before abandoning the form.
		if event.Key() == tcell.KeyEscape {
			se.showExitConfirmation("new-domain", "domains")
			return nil
		}
		return event
	})

	se.pages.AddPage("new-domain", form, true, true)
}
|
||||
|
||||
// showDomainEditor displays a dialog to edit the existing domain at the given
// index. Save applies name/description changes (name required); Delete opens
// a confirmation dialog; Back discards edits. Escape goes through the shared
// exit-confirmation dialog.
func (se *SchemaEditor) showDomainEditor(index int, domain *models.Domain) {
	form := tview.NewForm()

	// Locals hold the edited values until Save writes them back via updateDomain.
	domainName := domain.Name
	domainDesc := domain.Description

	form.AddInputField("Name", domainName, 40, nil, func(value string) {
		domainName = value
	})

	form.AddInputField("Description", domainDesc, 50, nil, func(value string) {
		domainDesc = value
	})

	form.AddButton("Save", func() {
		// Name is required; stay on the form otherwise.
		if domainName == "" {
			return
		}
		se.updateDomain(index, domainName, domainDesc)
		se.pages.RemovePage("edit-domain")
		se.pages.RemovePage("domains")
		se.showDomainList()
	})

	form.AddButton("Delete", func() {
		// Deletion is confirmed (and performed) by the confirmation dialog.
		se.showDeleteDomainConfirm(index)
	})

	form.AddButton("Back", func() {
		// Discard edits and rebuild the domain list.
		se.pages.RemovePage("edit-domain")
		se.pages.RemovePage("domains")
		se.showDomainList()
	})

	form.SetBorder(true).SetTitle(" Edit Domain ").SetTitleAlign(tview.AlignLeft)
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		// Escape asks for confirmation before abandoning unsaved edits.
		if event.Key() == tcell.KeyEscape {
			se.showExitConfirmation("edit-domain", "domains")
			return nil
		}
		return event
	})

	se.pages.AddPage("edit-domain", form, true, true)
}
|
||||
|
||||
// showDeleteDomainConfirm shows a confirmation dialog before deleting a domain
|
||||
func (se *SchemaEditor) showDeleteDomainConfirm(index int) {
|
||||
modal := tview.NewModal().
|
||||
SetText(fmt.Sprintf("Delete domain '%s'? This action cannot be undone.", se.db.Domains[index].Name)).
|
||||
AddButtons([]string{"Cancel", "Delete"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "Delete" {
|
||||
se.deleteDomain(index)
|
||||
se.pages.RemovePage("delete-domain-confirm")
|
||||
se.pages.RemovePage("edit-domain")
|
||||
se.pages.RemovePage("domains")
|
||||
se.showDomainList()
|
||||
} else {
|
||||
se.pages.RemovePage("delete-domain-confirm")
|
||||
}
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("delete-domain-confirm")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddAndSwitchToPage("delete-domain-confirm", modal, true)
|
||||
}
|
||||
73
pkg/ui/editor.go
Normal file
73
pkg/ui/editor.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package ui
|
||||
|
||||
import (
	"fmt"
	"sort"

	"github.com/rivo/tview"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
)
|
||||
|
||||
// SchemaEditor represents the interactive schema editor: a tview application
// whose screens and dialogs are stacked as named pages.
type SchemaEditor struct {
	db         *models.Database   // schema being edited; may be nil until loaded or created
	app        *tview.Application // terminal UI application
	pages      *tview.Pages       // named page stack for all screens/dialogs
	loadConfig *LoadConfig        // where the schema was loaded from; nil if none yet
	saveConfig *SaveConfig        // last save target; nil if none yet
}
|
||||
|
||||
// NewSchemaEditor creates a new schema editor
|
||||
func NewSchemaEditor(db *models.Database) *SchemaEditor {
|
||||
return &SchemaEditor{
|
||||
db: db,
|
||||
app: tview.NewApplication(),
|
||||
pages: tview.NewPages(),
|
||||
loadConfig: nil,
|
||||
saveConfig: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSchemaEditorWithConfigs creates a new schema editor with load/save configurations
|
||||
func NewSchemaEditorWithConfigs(db *models.Database, loadConfig *LoadConfig, saveConfig *SaveConfig) *SchemaEditor {
|
||||
return &SchemaEditor{
|
||||
db: db,
|
||||
app: tview.NewApplication(),
|
||||
pages: tview.NewPages(),
|
||||
loadConfig: loadConfig,
|
||||
saveConfig: saveConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the interactive editor
|
||||
func (se *SchemaEditor) Run() error {
|
||||
// If no database is loaded, show load screen
|
||||
if se.db == nil {
|
||||
se.showLoadScreen()
|
||||
} else {
|
||||
// Create main menu view
|
||||
mainMenu := se.createMainMenu()
|
||||
se.pages.AddPage("main", mainMenu, true, true)
|
||||
}
|
||||
|
||||
// Run the application
|
||||
if err := se.app.SetRoot(se.pages, true).Run(); err != nil {
|
||||
return fmt.Errorf("application error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDatabase returns the current database
|
||||
func (se *SchemaEditor) GetDatabase() *models.Database {
|
||||
return se.db
|
||||
}
|
||||
|
||||
// Helper function to get sorted column names
|
||||
func getColumnNames(table *models.Table) []string {
|
||||
names := make([]string, 0, len(table.Columns))
|
||||
for name := range table.Columns {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
791
pkg/ui/load_save_screens.go
Normal file
791
pkg/ui/load_save_screens.go
Normal file
@@ -0,0 +1,791 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/merge"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
rbun "git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
|
||||
rdbml "git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
|
||||
rdctx "git.warky.dev/wdevs/relspecgo/pkg/readers/dctx"
|
||||
rdrawdb "git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb"
|
||||
rdrizzle "git.warky.dev/wdevs/relspecgo/pkg/readers/drizzle"
|
||||
rgorm "git.warky.dev/wdevs/relspecgo/pkg/readers/gorm"
|
||||
rgraphql "git.warky.dev/wdevs/relspecgo/pkg/readers/graphql"
|
||||
rjson "git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
rpgsql "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
rprisma "git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
rtypeorm "git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
ryaml "git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
|
||||
wdbml "git.warky.dev/wdevs/relspecgo/pkg/writers/dbml"
|
||||
wdctx "git.warky.dev/wdevs/relspecgo/pkg/writers/dctx"
|
||||
wdrawdb "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
|
||||
wdrizzle "git.warky.dev/wdevs/relspecgo/pkg/writers/drizzle"
|
||||
wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
wgraphql "git.warky.dev/wdevs/relspecgo/pkg/writers/graphql"
|
||||
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
|
||||
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
)
|
||||
|
||||
// LoadConfig holds the configuration used to load a database.
type LoadConfig struct {
	SourceType string // reader format key, e.g. "dbml", "json", "pgsql"
	FilePath   string // input file path (file-based formats; unused for pgsql)
	ConnString string // PostgreSQL connection string (pgsql only)
}
|
||||
|
||||
// SaveConfig holds the configuration used to save a database.
type SaveConfig struct {
	TargetType string // writer format key, e.g. "dbml", "json", "yaml"
	FilePath   string // output file path
	ConnString string // connection string; NOTE(review): never written in visible code — confirm it is used elsewhere
}
|
||||
|
||||
// showLoadScreen displays the database load screen
|
||||
func (se *SchemaEditor) showLoadScreen() {
|
||||
flex := tview.NewFlex().SetDirection(tview.FlexRow)
|
||||
|
||||
// Title
|
||||
title := tview.NewTextView().
|
||||
SetText("[::b]Load Database Schema").
|
||||
SetTextAlign(tview.AlignCenter).
|
||||
SetDynamicColors(true)
|
||||
|
||||
// Form
|
||||
form := tview.NewForm()
|
||||
form.SetBorder(true).SetTitle(" Load Configuration ").SetTitleAlign(tview.AlignLeft)
|
||||
|
||||
// Format selection
|
||||
formatOptions := []string{
|
||||
"dbml", "dctx", "drawdb", "graphql", "json", "yaml",
|
||||
"gorm", "bun", "drizzle", "prisma", "typeorm", "pgsql",
|
||||
}
|
||||
selectedFormat := 0
|
||||
currentFormat := formatOptions[selectedFormat]
|
||||
|
||||
// File path input
|
||||
filePath := ""
|
||||
connString := ""
|
||||
|
||||
form.AddDropDown("Format", formatOptions, 0, func(option string, index int) {
|
||||
selectedFormat = index
|
||||
currentFormat = option
|
||||
})
|
||||
|
||||
form.AddInputField("File Path", "", 50, nil, func(value string) {
|
||||
filePath = value
|
||||
})
|
||||
|
||||
form.AddInputField("Connection String", "", 50, nil, func(value string) {
|
||||
connString = value
|
||||
})
|
||||
|
||||
form.AddTextView("Help", getLoadHelpText(), 0, 5, true, false)
|
||||
|
||||
// Buttons
|
||||
form.AddButton("Load [l]", func() {
|
||||
se.loadDatabase(currentFormat, filePath, connString)
|
||||
})
|
||||
|
||||
form.AddButton("Create New [n]", func() {
|
||||
se.createNewDatabase()
|
||||
})
|
||||
|
||||
form.AddButton("Exit [q]", func() {
|
||||
se.app.Stop()
|
||||
})
|
||||
|
||||
// Keyboard shortcuts
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.app.Stop()
|
||||
return nil
|
||||
}
|
||||
switch event.Rune() {
|
||||
case 'l':
|
||||
se.loadDatabase(currentFormat, filePath, connString)
|
||||
return nil
|
||||
case 'n':
|
||||
se.createNewDatabase()
|
||||
return nil
|
||||
case 'q':
|
||||
se.app.Stop()
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
// Tab navigation
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.app.Stop()
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'l' || event.Rune() == 'n' || event.Rune() == 'q' {
|
||||
return event
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
flex.AddItem(title, 1, 0, false).
|
||||
AddItem(form, 0, 1, true)
|
||||
|
||||
se.pages.AddAndSwitchToPage("load-database", flex, true)
|
||||
}
|
||||
|
||||
// showSaveScreen displays the save database screen
|
||||
func (se *SchemaEditor) showSaveScreen() {
|
||||
flex := tview.NewFlex().SetDirection(tview.FlexRow)
|
||||
|
||||
// Title
|
||||
title := tview.NewTextView().
|
||||
SetText("[::b]Save Database Schema").
|
||||
SetTextAlign(tview.AlignCenter).
|
||||
SetDynamicColors(true)
|
||||
|
||||
// Form
|
||||
form := tview.NewForm()
|
||||
form.SetBorder(true).SetTitle(" Save Configuration ").SetTitleAlign(tview.AlignLeft)
|
||||
|
||||
// Format selection
|
||||
formatOptions := []string{
|
||||
"dbml", "dctx", "drawdb", "graphql", "json", "yaml",
|
||||
"gorm", "bun", "drizzle", "prisma", "typeorm", "pgsql",
|
||||
}
|
||||
selectedFormat := 0
|
||||
currentFormat := formatOptions[selectedFormat]
|
||||
|
||||
// File path input
|
||||
filePath := ""
|
||||
if se.saveConfig != nil {
|
||||
// Pre-populate with existing save config
|
||||
for i, format := range formatOptions {
|
||||
if format == se.saveConfig.TargetType {
|
||||
selectedFormat = i
|
||||
currentFormat = format
|
||||
break
|
||||
}
|
||||
}
|
||||
filePath = se.saveConfig.FilePath
|
||||
}
|
||||
|
||||
form.AddDropDown("Format", formatOptions, selectedFormat, func(option string, index int) {
|
||||
selectedFormat = index
|
||||
currentFormat = option
|
||||
})
|
||||
|
||||
form.AddInputField("File Path", filePath, 50, nil, func(value string) {
|
||||
filePath = value
|
||||
})
|
||||
|
||||
form.AddTextView("Help", getSaveHelpText(), 0, 5, true, false)
|
||||
|
||||
// Buttons
|
||||
form.AddButton("Save [s]", func() {
|
||||
se.saveDatabase(currentFormat, filePath)
|
||||
})
|
||||
|
||||
form.AddButton("Update Existing Database [u]", func() {
|
||||
// Use saveConfig if available, otherwise use loadConfig
|
||||
if se.saveConfig != nil {
|
||||
se.showUpdateExistingDatabaseConfirm()
|
||||
} else if se.loadConfig != nil {
|
||||
se.showUpdateExistingDatabaseConfirm()
|
||||
} else {
|
||||
se.showErrorDialog("Error", "No database source found. Use Save instead.")
|
||||
}
|
||||
})
|
||||
|
||||
form.AddButton("Back [b]", func() {
|
||||
se.pages.RemovePage("save-database")
|
||||
se.pages.SwitchToPage("main")
|
||||
})
|
||||
|
||||
// Keyboard shortcuts
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("save-database")
|
||||
se.pages.SwitchToPage("main")
|
||||
return nil
|
||||
}
|
||||
switch event.Rune() {
|
||||
case 's':
|
||||
se.saveDatabase(currentFormat, filePath)
|
||||
return nil
|
||||
case 'u':
|
||||
// Use saveConfig if available, otherwise use loadConfig
|
||||
if se.saveConfig != nil {
|
||||
se.showUpdateExistingDatabaseConfirm()
|
||||
} else if se.loadConfig != nil {
|
||||
se.showUpdateExistingDatabaseConfirm()
|
||||
} else {
|
||||
se.showErrorDialog("Error", "No database source found. Use Save instead.")
|
||||
}
|
||||
return nil
|
||||
case 'b':
|
||||
se.pages.RemovePage("save-database")
|
||||
se.pages.SwitchToPage("main")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
flex.AddItem(title, 1, 0, false).
|
||||
AddItem(form, 0, 1, true)
|
||||
|
||||
se.pages.AddAndSwitchToPage("save-database", flex, true)
|
||||
}
|
||||
|
||||
// loadDatabase loads a database using the given format key. For "pgsql" the
// connection string is required; for every file-based format the file path is
// required (a leading '~' is expanded to the user's home directory). On
// success the load configuration is remembered (so saves can target the same
// source), se.db is replaced, and the main menu is rebuilt; on failure an
// error dialog is shown and the load screen stays up.
func (se *SchemaEditor) loadDatabase(format, filePath, connString string) {
	// Validate input
	if format == "pgsql" {
		if connString == "" {
			se.showErrorDialog("Error", "Connection string is required for PostgreSQL")
			return
		}
	} else {
		if filePath == "" {
			se.showErrorDialog("Error", "File path is required for "+format)
			return
		}
		// Expand home directory; on lookup failure the path is used as-is.
		if len(filePath) > 0 && filePath[0] == '~' {
			home, err := os.UserHomeDir()
			if err == nil {
				filePath = filepath.Join(home, filePath[1:])
			}
		}
	}

	// Create reader for the selected format; each format has its own package.
	var reader readers.Reader
	switch format {
	case "dbml":
		reader = rdbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "dctx":
		reader = rdctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "drawdb":
		reader = rdrawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "graphql":
		reader = rgraphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "json":
		reader = rjson.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "yaml":
		reader = ryaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "gorm":
		reader = rgorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "bun":
		reader = rbun.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "drizzle":
		reader = rdrizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "prisma":
		reader = rprisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "typeorm":
		reader = rtypeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "pgsql":
		// pgsql reads directly from a live database, not from a file.
		reader = rpgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
	default:
		se.showErrorDialog("Error", "Unsupported format: "+format)
		return
	}

	// Read database
	db, err := reader.ReadDatabase()
	if err != nil {
		se.showErrorDialog("Load Error", fmt.Sprintf("Failed to load database: %v", err))
		return
	}

	// Store load config so "Update Existing Database" can find the source later.
	se.loadConfig = &LoadConfig{
		SourceType: format,
		FilePath:   filePath,
		ConnString: connString,
	}

	// Update database
	se.db = db

	// Show success and switch to main menu once the dialog is dismissed.
	se.showSuccessDialog("Load Complete", fmt.Sprintf("Successfully loaded database '%s'", db.Name), func() {
		se.pages.RemovePage("load-database")
		se.pages.RemovePage("main")
		se.pages.AddPage("main", se.createMainMenu(), true, true)
	})
}
|
||||
|
||||
// saveDatabase saves the database to the specified configuration
|
||||
func (se *SchemaEditor) saveDatabase(format, filePath string) {
|
||||
// Validate input
|
||||
if format == "pgsql" {
|
||||
se.showErrorDialog("Error", "Direct PostgreSQL save is not supported from the UI. Use --to pgsql --to-path output.sql")
|
||||
return
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
se.showErrorDialog("Error", "File path is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Expand home directory
|
||||
if len(filePath) > 0 && filePath[0] == '~' {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
filePath = filepath.Join(home, filePath[1:])
|
||||
}
|
||||
}
|
||||
|
||||
// Create writer
|
||||
var writer writers.Writer
|
||||
switch format {
|
||||
case "dbml":
|
||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "dctx":
|
||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "drawdb":
|
||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "graphql":
|
||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "json":
|
||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "yaml":
|
||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "gorm":
|
||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "bun":
|
||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "drizzle":
|
||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "prisma":
|
||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "typeorm":
|
||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "pgsql":
|
||||
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
default:
|
||||
se.showErrorDialog("Error", "Unsupported format: "+format)
|
||||
return
|
||||
}
|
||||
|
||||
// Write database
|
||||
err := writer.WriteDatabase(se.db)
|
||||
if err != nil {
|
||||
se.showErrorDialog("Save Error", fmt.Sprintf("Failed to save database: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Store save config
|
||||
se.saveConfig = &SaveConfig{
|
||||
TargetType: format,
|
||||
FilePath: filePath,
|
||||
}
|
||||
|
||||
// Show success
|
||||
se.showSuccessDialog("Save Complete", fmt.Sprintf("Successfully saved database to %s", filePath), func() {
|
||||
se.pages.RemovePage("save-database")
|
||||
se.pages.SwitchToPage("main")
|
||||
})
|
||||
}
|
||||
|
||||
// createNewDatabase creates a new empty database
|
||||
func (se *SchemaEditor) createNewDatabase() {
|
||||
// Create a new empty database
|
||||
se.db = &models.Database{
|
||||
Name: "New Database",
|
||||
Schemas: []*models.Schema{},
|
||||
}
|
||||
|
||||
// Clear load config
|
||||
se.loadConfig = nil
|
||||
|
||||
// Show success and switch to main menu
|
||||
se.showSuccessDialog("New Database", "Created new empty database", func() {
|
||||
se.pages.RemovePage("load-database")
|
||||
se.pages.AddPage("main", se.createMainMenu(), true, true)
|
||||
})
|
||||
}
|
||||
|
||||
// showErrorDialog displays an error dialog
|
||||
func (se *SchemaEditor) showErrorDialog(_title, message string) {
|
||||
modal := tview.NewModal().
|
||||
SetText(message).
|
||||
AddButtons([]string{"OK"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
se.pages.RemovePage("error-dialog")
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("error-dialog")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("error-dialog", modal, true, true)
|
||||
}
|
||||
|
||||
// showSuccessDialog displays a success dialog
|
||||
func (se *SchemaEditor) showSuccessDialog(_title, message string, onClose func()) {
|
||||
modal := tview.NewModal().
|
||||
SetText(message).
|
||||
AddButtons([]string{"OK"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
se.pages.RemovePage("success-dialog")
|
||||
if onClose != nil {
|
||||
onClose()
|
||||
}
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("success-dialog")
|
||||
if onClose != nil {
|
||||
onClose()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("success-dialog", modal, true, true)
|
||||
}
|
||||
|
||||
// getLoadHelpText returns the help text for the load screen
|
||||
func getLoadHelpText() string {
|
||||
return `File-based formats: dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm
|
||||
Database formats: pgsql (requires connection string)
|
||||
|
||||
Examples:
|
||||
- File path: ~/schemas/mydb.dbml or /path/to/schema.json
|
||||
- Connection: postgres://user:pass@localhost/dbname`
|
||||
}
|
||||
|
||||
// showUpdateExistingDatabaseConfirm displays a confirmation dialog before
// overwriting the database's previous save or load target. The save target
// takes precedence over the load source; if neither is recorded the method
// silently does nothing.
func (se *SchemaEditor) showUpdateExistingDatabaseConfirm() {
	// Use saveConfig if available, otherwise use loadConfig
	var targetType, targetPath string
	if se.saveConfig != nil {
		targetType = se.saveConfig.TargetType
		targetPath = se.saveConfig.FilePath
	} else if se.loadConfig != nil {
		targetType = se.loadConfig.SourceType
		targetPath = se.loadConfig.FilePath
	} else {
		// No known target to update; nothing to confirm.
		return
	}

	confirmText := fmt.Sprintf("Update existing database?\n\nFormat: %s\nPath: %s\n\nThis will overwrite the source.",
		targetType, targetPath)

	modal := tview.NewModal().
		SetText(confirmText).
		AddButtons([]string{"Cancel", "Update"}).
		SetDoneFunc(func(buttonIndex int, buttonLabel string) {
			if buttonLabel == "Update" {
				// Tear down confirmation and save screens before writing.
				se.pages.RemovePage("update-confirm")
				se.pages.RemovePage("save-database")
				se.saveDatabase(targetType, targetPath)
				// NOTE(review): saveDatabase's success dialog also switches
				// to "main" on dismissal; this switch looks redundant —
				// confirm before removing.
				se.pages.SwitchToPage("main")
			} else {
				se.pages.RemovePage("update-confirm")
			}
		})

	// Escape cancels the confirmation without saving.
	modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.pages.RemovePage("update-confirm")
			return nil
		}
		return event
	})

	se.pages.AddAndSwitchToPage("update-confirm", modal, true)
}
|
||||
|
||||
// getSaveHelpText returns the help text for the save screen
|
||||
func getSaveHelpText() string {
|
||||
return `File-based formats: dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql (SQL export)
|
||||
|
||||
Examples:
|
||||
- File: ~/schemas/mydb.dbml
|
||||
- Directory (for code formats): ./models/`
|
||||
}
|
||||
|
||||
// showImportScreen displays the import/merge database screen: a form that
// gathers the source format, its file path or connection string, and a set
// of skip flags, then hands everything to importAndMergeDatabase. Form
// values are captured by the closures below, so the latest user input is
// read at submit time.
func (se *SchemaEditor) showImportScreen() {
	flex := tview.NewFlex().SetDirection(tview.FlexRow)

	// Title
	title := tview.NewTextView().
		SetText("[::b]Import & Merge Database Schema").
		SetTextAlign(tview.AlignCenter).
		SetDynamicColors(true)

	// Form
	form := tview.NewForm()
	form.SetBorder(true).SetTitle(" Import Configuration ").SetTitleAlign(tview.AlignLeft)

	// Format selection. Only currentFormat is consumed at submit time;
	// selectedFormat mirrors the dropdown index for completeness.
	formatOptions := []string{
		"dbml", "dctx", "drawdb", "graphql", "json", "yaml",
		"gorm", "bun", "drizzle", "prisma", "typeorm", "pgsql",
	}
	selectedFormat := 0
	currentFormat := formatOptions[selectedFormat]

	// Form state, written by the field-change closures below.
	filePath := ""
	connString := ""
	skipDomains := false
	skipRelations := false
	skipEnums := false
	skipViews := false
	skipSequences := false
	skipTables := ""

	form.AddDropDown("Format", formatOptions, 0, func(option string, index int) {
		selectedFormat = index
		currentFormat = option
	})

	form.AddInputField("File Path", "", 50, nil, func(value string) {
		filePath = value
	})

	form.AddInputField("Connection String", "", 50, nil, func(value string) {
		connString = value
	})

	form.AddInputField("Skip Tables (comma-separated)", "", 50, nil, func(value string) {
		skipTables = value
	})

	form.AddCheckbox("Skip Domains", false, func(checked bool) {
		skipDomains = checked
	})

	form.AddCheckbox("Skip Relations", false, func(checked bool) {
		skipRelations = checked
	})

	form.AddCheckbox("Skip Enums", false, func(checked bool) {
		skipEnums = checked
	})

	form.AddCheckbox("Skip Views", false, func(checked bool) {
		skipViews = checked
	})

	form.AddCheckbox("Skip Sequences", false, func(checked bool) {
		skipSequences = checked
	})

	form.AddTextView("Help", getImportHelpText(), 0, 7, true, false)

	// Buttons
	form.AddButton("Import & Merge [i]", func() {
		se.importAndMergeDatabase(currentFormat, filePath, connString, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
	})

	form.AddButton("Back [b]", func() {
		se.pages.RemovePage("import-database")
		se.pages.SwitchToPage("main")
	})

	form.AddButton("Exit [q]", func() {
		se.app.Stop()
	})

	// Keyboard shortcuts: Escape/b go back, i submits, q exits the app.
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.pages.RemovePage("import-database")
			se.pages.SwitchToPage("main")
			return nil
		}
		switch event.Rune() {
		case 'i':
			se.importAndMergeDatabase(currentFormat, filePath, connString, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
			return nil
		case 'b':
			se.pages.RemovePage("import-database")
			se.pages.SwitchToPage("main")
			return nil
		case 'q':
			se.app.Stop()
			return nil
		}
		return event
	})

	flex.AddItem(title, 1, 0, false).
		AddItem(form, 0, 1, true)

	se.pages.AddAndSwitchToPage("import-database", flex, true)
}
|
||||
|
||||
// importAndMergeDatabase imports a database from the given format and
// source (file path for file-based formats, connection string for pgsql),
// then shows a confirmation dialog before merging it into the current
// database. The skip* flags and skipTables list are passed through to the
// confirmation/merge steps unmodified. Errors are reported via modal
// dialogs rather than returned.
func (se *SchemaEditor) importAndMergeDatabase(format, filePath, connString string, skipDomains, skipRelations, skipEnums, skipViews, skipSequences bool, skipTables string) {
	// Validate input: pgsql needs a connection string, everything else a path.
	if format == "pgsql" {
		if connString == "" {
			se.showErrorDialog("Error", "Connection string is required for PostgreSQL")
			return
		}
	} else {
		if filePath == "" {
			se.showErrorDialog("Error", "File path is required for "+format)
			return
		}
		// Expand home directory; on lookup failure the path is left as-is.
		if len(filePath) > 0 && filePath[0] == '~' {
			home, err := os.UserHomeDir()
			if err == nil {
				filePath = filepath.Join(home, filePath[1:])
			}
		}
	}

	// Create reader for the selected source format.
	var reader readers.Reader
	switch format {
	case "dbml":
		reader = rdbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "dctx":
		reader = rdctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "drawdb":
		reader = rdrawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "graphql":
		reader = rgraphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "json":
		reader = rjson.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "yaml":
		reader = ryaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "gorm":
		reader = rgorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "bun":
		reader = rbun.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "drizzle":
		reader = rdrizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "prisma":
		reader = rprisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "typeorm":
		reader = rtypeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
	case "pgsql":
		reader = rpgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
	default:
		se.showErrorDialog("Error", "Unsupported format: "+format)
		return
	}

	// Read the database to import
	importDb, err := reader.ReadDatabase()
	if err != nil {
		se.showErrorDialog("Import Error", fmt.Sprintf("Failed to read database: %v", err))
		return
	}

	// Show confirmation dialog; the actual merge happens there.
	se.showImportConfirmation(importDb, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
}
|
||||
|
||||
// showImportConfirmation shows a modal confirmation before merging importDb
// into the current database. Choosing "Merge" delegates to performMerge with
// the given skip flags; "Cancel" or Escape aborts (Escape additionally
// returns to the import screen).
func (se *SchemaEditor) showImportConfirmation(importDb *models.Database, skipDomains, skipRelations, skipEnums, skipViews, skipSequences bool, skipTables string) {
	confirmText := fmt.Sprintf("Import & Merge Database?\n\nSource: %s\nTarget: %s\n\nThis will add missing schemas, tables, columns, and other objects from the source to your database.\n\nExisting items will NOT be modified.",
		importDb.Name, se.db.Name)

	modal := tview.NewModal().
		SetText(confirmText).
		AddButtons([]string{"Cancel", "Merge"}).
		SetDoneFunc(func(buttonIndex int, buttonLabel string) {
			// Dialog is dismissed regardless of the chosen button.
			se.pages.RemovePage("import-confirm")
			if buttonLabel == "Merge" {
				se.performMerge(importDb, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
			}
		})

	// Escape cancels and returns to the import configuration screen.
	modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.pages.RemovePage("import-confirm")
			se.pages.SwitchToPage("import-database")
			return nil
		}
		return event
	})

	se.pages.AddAndSwitchToPage("import-confirm", modal, true)
}
|
||||
|
||||
// performMerge performs the actual merge operation
|
||||
func (se *SchemaEditor) performMerge(importDb *models.Database, skipDomains, skipRelations, skipEnums, skipViews, skipSequences bool, skipTables string) {
|
||||
// Create merge options
|
||||
opts := &merge.MergeOptions{
|
||||
SkipDomains: skipDomains,
|
||||
SkipRelations: skipRelations,
|
||||
SkipEnums: skipEnums,
|
||||
SkipViews: skipViews,
|
||||
SkipSequences: skipSequences,
|
||||
}
|
||||
|
||||
// Parse skip tables
|
||||
if skipTables != "" {
|
||||
opts.SkipTableNames = parseSkipTablesUI(skipTables)
|
||||
}
|
||||
|
||||
// Perform the merge
|
||||
result := merge.MergeDatabases(se.db, importDb, opts)
|
||||
|
||||
// Update the database timestamp
|
||||
se.db.UpdateDate()
|
||||
|
||||
// Show success dialog with summary
|
||||
summary := merge.GetMergeSummary(result)
|
||||
se.showSuccessDialog("Import Complete", summary, func() {
|
||||
se.pages.RemovePage("import-database")
|
||||
se.pages.RemovePage("main")
|
||||
se.pages.AddPage("main", se.createMainMenu(), true, true)
|
||||
})
|
||||
}
|
||||
|
||||
// getImportHelpText returns the help text for the import screen
|
||||
func getImportHelpText() string {
|
||||
return `Import & Merge: Adds missing schemas, tables, columns, and other objects to your existing database.
|
||||
|
||||
File-based formats: dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm
|
||||
Database formats: pgsql (requires connection string)
|
||||
|
||||
Skip options: Check to exclude specific object types from the merge.`
|
||||
}
|
||||
|
||||
// parseSkipTablesUI converts a comma-separated table list into a lookup set
// keyed by lower-cased name, enabling case-insensitive matching during merge.
// Blank entries (empty input, consecutive commas, whitespace-only items) are
// ignored; an empty input yields an empty, non-nil map.
func parseSkipTablesUI(skipTablesStr string) map[string]bool {
	result := make(map[string]bool)
	for _, raw := range strings.Split(skipTablesStr, ",") {
		if cleaned := strings.ToLower(strings.TrimSpace(raw)); cleaned != "" {
			result[cleaned] = true
		}
	}
	return result
}
|
||||
65
pkg/ui/main_menu.go
Normal file
65
pkg/ui/main_menu.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
// createMainMenu creates the main menu screen: a title showing the database
// name and last-update stamp above a list of editor actions. Returns the
// composed primitive for the caller to add as the "main" page.
func (se *SchemaEditor) createMainMenu() tview.Primitive {
	flex := tview.NewFlex().SetDirection(tview.FlexRow)

	// Title with database name; fall back to "Untitled" for unnamed databases.
	dbName := se.db.Name
	if dbName == "" {
		dbName = "Untitled"
	}
	updateAtStr := ""
	if se.db.UpdatedAt != "" {
		updateAtStr = fmt.Sprintf("Updated @ %s", se.db.UpdatedAt)
	}
	titleText := fmt.Sprintf("[::b]RelSpec Schema Editor\n[::d]Database: %s %s\n[::d]Press arrow keys to navigate, Enter to select", dbName, updateAtStr)
	title := tview.NewTextView().
		SetText(titleText).
		SetDynamicColors(true)

	// Menu options; each item's rune is its keyboard shortcut.
	menu := tview.NewList().
		AddItem("Edit Database", "Edit database name, description, and properties", 'e', func() {
			se.showEditDatabaseForm()
		}).
		AddItem("Manage Schemas", "View, create, edit, and delete schemas", 's', func() {
			se.showSchemaList()
		}).
		AddItem("Manage Tables", "View and manage tables in schemas", 't', func() {
			se.showTableList()
		}).
		AddItem("Manage Domains", "View, create, edit, and delete domains", 'd', func() {
			se.showDomainList()
		}).
		AddItem("Import & Merge", "Import and merge schema from another database", 'i', func() {
			se.showImportScreen()
		}).
		AddItem("Save Database", "Save database to file or database", 'w', func() {
			se.showSaveScreen()
		}).
		AddItem("Exit Editor", "Exit the editor", 'q', func() {
			se.app.Stop()
		})

	menu.SetBorder(true).SetTitle(" Menu ").SetTitleAlign(tview.AlignLeft)
	// Escape from the main menu asks for exit confirmation instead of quitting.
	menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.showExitEditorConfirm()
			return nil
		}
		return event
	})

	flex.AddItem(title, 5, 0, false).
		AddItem(menu, 0, 1, true)

	return flex
}
|
||||
55
pkg/ui/schema_dataops.go
Normal file
55
pkg/ui/schema_dataops.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package ui
|
||||
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
|
||||
// Schema data operations - business logic for schema management
|
||||
|
||||
// CreateSchema creates a new schema and adds it to the database
|
||||
func (se *SchemaEditor) CreateSchema(name, description string) *models.Schema {
|
||||
newSchema := &models.Schema{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Tables: make([]*models.Table, 0),
|
||||
Sequences: make([]*models.Sequence, 0),
|
||||
Enums: make([]*models.Enum, 0),
|
||||
}
|
||||
se.db.UpdateDate()
|
||||
se.db.Schemas = append(se.db.Schemas, newSchema)
|
||||
return newSchema
|
||||
}
|
||||
|
||||
// UpdateSchema updates an existing schema's properties
|
||||
func (se *SchemaEditor) UpdateSchema(schemaIndex int, name, owner, description string) {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return
|
||||
}
|
||||
se.db.UpdateDate()
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
schema.Name = name
|
||||
schema.Owner = owner
|
||||
schema.Description = description
|
||||
schema.UpdateDate()
|
||||
}
|
||||
|
||||
// DeleteSchema removes a schema from the database
|
||||
func (se *SchemaEditor) DeleteSchema(schemaIndex int) bool {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return false
|
||||
}
|
||||
se.db.UpdateDate()
|
||||
se.db.Schemas = append(se.db.Schemas[:schemaIndex], se.db.Schemas[schemaIndex+1:]...)
|
||||
return true
|
||||
}
|
||||
|
||||
// GetSchema returns a schema by index
|
||||
func (se *SchemaEditor) GetSchema(schemaIndex int) *models.Schema {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
return se.db.Schemas[schemaIndex]
|
||||
}
|
||||
|
||||
// GetAllSchemas returns the database's schema slice directly (not a copy);
// callers mutating it mutate the editor's state.
func (se *SchemaEditor) GetAllSchemas() []*models.Schema {
	return se.db.Schemas
}
|
||||
362
pkg/ui/schema_screens.go
Normal file
362
pkg/ui/schema_screens.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// showSchemaList displays the schema management screen: a table of all
// schemas (name, sequence, object counts, GUID, description) plus
// New Schema / Back buttons. Tab/Shift+Tab cycle focus table -> New -> Back,
// Enter opens the selected schema's editor, and 'n'/'b'/Escape are shortcuts
// for New/Back.
func (se *SchemaEditor) showSchemaList() {
	flex := tview.NewFlex().SetDirection(tview.FlexRow)

	// Title
	title := tview.NewTextView().
		SetText("[::b]Manage Schemas").
		SetDynamicColors(true).
		SetTextAlign(tview.AlignCenter)

	// Create schemas table; row selection only, header row fixed.
	schemaTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)

	// Add header row with padding for full width. headerWidths deliberately
	// has one fewer entry than headers: the final Description column takes
	// the remaining space unpadded.
	headers := []string{"Name", "Sequence", "Total Tables", "Total Sequences", "Total Views", "GUID", "Description"}
	headerWidths := []int{20, 15, 20, 20, 15, 36} // Last column takes remaining space
	for i, header := range headers {
		padding := ""
		if i < len(headerWidths) {
			// Assumes every header is shorter than its width (true for
			// the labels above); a longer label would panic in Repeat.
			padding = strings.Repeat(" ", headerWidths[i]-len(header))
		}
		cell := tview.NewTableCell(header + padding).
			SetTextColor(tcell.ColorYellow).
			SetSelectable(false).
			SetAlign(tview.AlignLeft)
		schemaTable.SetCell(0, i, cell)
	}

	// Add existing schemas, one row each (offset by 1 for the header).
	for row, schema := range se.db.Schemas {
		schema := schema // capture for closure
		// NOTE(review): no closure in this loop body actually captures
		// schema; the reassignment looks vestigial — confirm before removing.

		// Name - pad to 20 chars
		nameStr := fmt.Sprintf("%-20s", schema.Name)
		nameCell := tview.NewTableCell(nameStr).SetSelectable(true)
		schemaTable.SetCell(row+1, 0, nameCell)

		// Sequence - pad to 15 chars
		seqStr := fmt.Sprintf("%-15s", fmt.Sprintf("%d", schema.Sequence))
		seqCell := tview.NewTableCell(seqStr).SetSelectable(true)
		schemaTable.SetCell(row+1, 1, seqCell)

		// Total Tables - pad to 20 chars
		tablesStr := fmt.Sprintf("%-20s", fmt.Sprintf("%d", len(schema.Tables)))
		tablesCell := tview.NewTableCell(tablesStr).SetSelectable(true)
		schemaTable.SetCell(row+1, 2, tablesCell)

		// Total Sequences - pad to 20 chars
		sequencesStr := fmt.Sprintf("%-20s", fmt.Sprintf("%d", len(schema.Sequences)))
		sequencesCell := tview.NewTableCell(sequencesStr).SetSelectable(true)
		schemaTable.SetCell(row+1, 3, sequencesCell)

		// Total Views - pad to 15 chars
		viewsStr := fmt.Sprintf("%-15s", fmt.Sprintf("%d", len(schema.Views)))
		viewsCell := tview.NewTableCell(viewsStr).SetSelectable(true)
		schemaTable.SetCell(row+1, 4, viewsCell)

		// GUID - pad to 36 chars
		guidStr := fmt.Sprintf("%-36s", schema.GUID)
		guidCell := tview.NewTableCell(guidStr).SetSelectable(true)
		schemaTable.SetCell(row+1, 5, guidCell)

		// Description - no padding, takes remaining space
		descCell := tview.NewTableCell(schema.Description).SetSelectable(true)
		schemaTable.SetCell(row+1, 6, descCell)
	}

	schemaTable.SetTitle(" Schemas ").SetBorder(true).SetTitleAlign(tview.AlignLeft)

	// Action buttons flex (define before input capture)
	btnFlex := tview.NewFlex()
	btnNewSchema := tview.NewButton("New Schema [n]").SetSelectedFunc(func() {
		se.showNewSchemaDialog()
	})
	btnBack := tview.NewButton("Back [b]").SetSelectedFunc(func() {
		se.pages.SwitchToPage("main")
		se.pages.RemovePage("schemas")
	})

	// Set up button input captures for Tab/Shift+Tab navigation
	btnNewSchema.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyBacktab {
			se.app.SetFocus(schemaTable)
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(btnBack)
			return nil
		}
		return event
	})

	btnBack.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyBacktab {
			se.app.SetFocus(btnNewSchema)
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(schemaTable)
			return nil
		}
		return event
	})

	btnFlex.AddItem(btnNewSchema, 0, 1, true).
		AddItem(btnBack, 0, 1, false)

	// Table-level keys: Escape/b back, Tab to buttons, Enter opens editor,
	// n creates a new schema.
	schemaTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.pages.SwitchToPage("main")
			se.pages.RemovePage("schemas")
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(btnNewSchema)
			return nil
		}
		if event.Key() == tcell.KeyEnter {
			row, _ := schemaTable.GetSelection()
			if row > 0 && row <= len(se.db.Schemas) { // Skip header row
				schemaIndex := row - 1
				se.showSchemaEditor(schemaIndex, se.db.Schemas[schemaIndex])
				return nil
			}
		}
		if event.Rune() == 'n' {
			se.showNewSchemaDialog()
			return nil
		}
		if event.Rune() == 'b' {
			se.pages.SwitchToPage("main")
			se.pages.RemovePage("schemas")
			return nil
		}
		return event
	})

	flex.AddItem(title, 1, 0, false).
		AddItem(schemaTable, 0, 1, true).
		AddItem(btnFlex, 1, 0, false)

	se.pages.AddPage("schemas", flex, true, true)
}
|
||||
|
||||
// showSchemaEditor shows the editor for a specific schema: its table list
// plus entries for creating a table and editing or deleting the schema.
// index is the schema's position in se.db.Schemas (passed on to the table
// and schema dialogs); schema is the schema itself.
func (se *SchemaEditor) showSchemaEditor(index int, schema *models.Schema) {
	flex := tview.NewFlex().SetDirection(tview.FlexRow)

	// Title
	title := tview.NewTextView().
		SetText(fmt.Sprintf("[::b]Schema: %s", schema.Name)).
		SetDynamicColors(true).
		SetTextAlign(tview.AlignCenter)

	// Schema info display
	info := tview.NewTextView().SetDynamicColors(true)
	info.SetText(fmt.Sprintf("Tables: %d | Description: %s",
		len(schema.Tables), schema.Description))

	// Table list
	tableList := tview.NewList().ShowSecondaryText(true)

	for i, table := range schema.Tables {
		tableIndex := i // capture loop values for the selection closure
		table := table
		colCount := len(table.Columns)
		// NOTE(review): the shortcut rune '0'+i is only a digit for the
		// first ten tables; beyond i=9 it yields non-digit runes — confirm
		// intended behavior for large schemas.
		tableList.AddItem(table.Name, fmt.Sprintf("%d columns", colCount), rune('0'+i), func() {
			se.showTableEditor(index, tableIndex, table)
		})
	}

	tableList.AddItem("[New Table]", "Add a new table to this schema", 'n', func() {
		se.showNewTableDialog(index)
	})

	tableList.AddItem("[Edit Schema Info]", "Edit schema properties", 'e', func() {
		se.showEditSchemaDialog(index)
	})

	tableList.AddItem("[Delete Schema]", "Delete this schema", 'd', func() {
		se.showDeleteSchemaConfirm(index)
	})

	tableList.SetBorder(true).SetTitle(" Tables ").SetTitleAlign(tview.AlignLeft)

	// Action buttons (define before input capture)
	btnFlex := tview.NewFlex()
	btnNewTable := tview.NewButton("New Table [n]").SetSelectedFunc(func() {
		se.showNewTableDialog(index)
	})
	btnBack := tview.NewButton("Back to Schemas [b]").SetSelectedFunc(func() {
		se.pages.RemovePage("schema-editor")
		se.pages.SwitchToPage("schemas")
	})

	// Set up button input captures for Tab/Shift+Tab navigation
	btnNewTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyBacktab {
			se.app.SetFocus(tableList)
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(btnBack)
			return nil
		}
		return event
	})

	btnBack.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyBacktab {
			se.app.SetFocus(btnNewTable)
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(tableList)
			return nil
		}
		return event
	})

	// List-level keys: Escape/b go back to the schema list, Tab to buttons.
	// Note 'b' intentionally falls through (no return nil) so the list also
	// receives the event.
	tableList.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.pages.RemovePage("schema-editor")
			se.pages.SwitchToPage("schemas")
			return nil
		}
		if event.Key() == tcell.KeyTab {
			se.app.SetFocus(btnNewTable)
			return nil
		}
		if event.Rune() == 'b' {
			se.pages.RemovePage("schema-editor")
			se.pages.SwitchToPage("schemas")
		}
		return event
	})

	btnFlex.AddItem(btnNewTable, 0, 1, true).
		AddItem(btnBack, 0, 1, false)

	flex.AddItem(title, 1, 0, false).
		AddItem(info, 2, 0, false).
		AddItem(tableList, 0, 1, true).
		AddItem(btnFlex, 1, 0, false)

	se.pages.AddPage("schema-editor", flex, true, true)
}
|
||||
|
||||
// showNewSchemaDialog shows a form for creating a new schema (name +
// description). Save silently does nothing when the name is empty; on
// success the schema list is rebuilt to show the new entry. Escape asks for
// exit confirmation via showExitConfirmation.
func (se *SchemaEditor) showNewSchemaDialog() {
	form := tview.NewForm()

	// Form state, written by the field-change closures below.
	schemaName := ""
	description := ""

	form.AddInputField("Schema Name", "", 40, nil, func(value string) {
		schemaName = value
	})

	form.AddInputField("Description", "", 40, nil, func(value string) {
		description = value
	})

	form.AddButton("Save", func() {
		if schemaName == "" {
			// Name is required; keep the dialog open without feedback.
			return
		}

		se.CreateSchema(schemaName, description)

		// Rebuild the schema list so the new schema appears.
		se.pages.RemovePage("new-schema")
		se.pages.RemovePage("schemas")
		se.showSchemaList()
	})

	form.AddButton("Back", func() {
		// Discard input and return to a rebuilt schema list.
		se.pages.RemovePage("new-schema")
		se.pages.RemovePage("schemas")
		se.showSchemaList()
	})

	form.SetBorder(true).SetTitle(" New Schema ").SetTitleAlign(tview.AlignLeft)
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.showExitConfirmation("new-schema", "schemas")
			return nil
		}
		return event
	})

	se.pages.AddPage("new-schema", form, true, true)
}
|
||||
|
||||
// showEditSchemaDialog shows a form for editing the schema at schemaIndex
// (name, owner, description, GUID). Changes are collected in locals and only
// applied on Save; Back discards them. Either way the schema editor screen
// is rebuilt afterwards. Escape asks for exit confirmation.
func (se *SchemaEditor) showEditSchemaDialog(schemaIndex int) {
	schema := se.db.Schemas[schemaIndex]
	form := tview.NewForm()

	// Local variables to collect changes; the schema is untouched until Save.
	newName := schema.Name
	newOwner := schema.Owner
	newDescription := schema.Description
	newGUID := schema.GUID

	form.AddInputField("Schema Name", schema.Name, 40, nil, func(value string) {
		newName = value
	})

	form.AddInputField("Owner", schema.Owner, 40, nil, func(value string) {
		newOwner = value
	})

	form.AddTextArea("Description", schema.Description, 40, 5, 0, func(value string) {
		newDescription = value
	})

	form.AddInputField("GUID", schema.GUID, 40, nil, func(value string) {
		newGUID = value
	})

	form.AddButton("Save", func() {
		// Apply changes using dataops; GUID is set directly since
		// UpdateSchema does not cover it.
		se.UpdateSchema(schemaIndex, newName, newOwner, newDescription)
		se.db.Schemas[schemaIndex].GUID = newGUID

		// Rebuild the schema editor to reflect the saved values.
		schema := se.db.Schemas[schemaIndex]
		se.pages.RemovePage("edit-schema")
		se.pages.RemovePage("schema-editor")
		se.showSchemaEditor(schemaIndex, schema)
	})

	form.AddButton("Back", func() {
		// Discard changes - don't apply them
		schema := se.db.Schemas[schemaIndex]
		se.pages.RemovePage("edit-schema")
		se.pages.RemovePage("schema-editor")
		se.showSchemaEditor(schemaIndex, schema)
	})

	form.SetBorder(true).SetTitle(" Edit Schema ").SetTitleAlign(tview.AlignLeft)
	form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
		if event.Key() == tcell.KeyEscape {
			se.showExitConfirmation("edit-schema", "schema-editor")
			return nil
		}
		return event
	})

	se.pages.AddPage("edit-schema", form, true, true)
}
|
||||
88
pkg/ui/table_dataops.go
Normal file
88
pkg/ui/table_dataops.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package ui
|
||||
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
|
||||
// Table data operations - business logic for table management
|
||||
|
||||
// CreateTable creates a new table and adds it to a schema
|
||||
func (se *SchemaEditor) CreateTable(schemaIndex int, name, description string) *models.Table {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
newTable := &models.Table{
|
||||
Name: name,
|
||||
Schema: schema.Name,
|
||||
Description: description,
|
||||
Columns: make(map[string]*models.Column),
|
||||
Constraints: make(map[string]*models.Constraint),
|
||||
Indexes: make(map[string]*models.Index),
|
||||
}
|
||||
schema.UpdateDate()
|
||||
schema.Tables = append(schema.Tables, newTable)
|
||||
return newTable
|
||||
}
|
||||
|
||||
// UpdateTable updates an existing table's properties
|
||||
func (se *SchemaEditor) UpdateTable(schemaIndex, tableIndex int, name, description string) {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return
|
||||
}
|
||||
schema.UpdateDate()
|
||||
table := schema.Tables[tableIndex]
|
||||
table.Name = name
|
||||
table.Description = description
|
||||
table.UpdateDate()
|
||||
}
|
||||
|
||||
// DeleteTable removes a table from a schema
|
||||
func (se *SchemaEditor) DeleteTable(schemaIndex, tableIndex int) bool {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return false
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return false
|
||||
}
|
||||
schema.UpdateDate()
|
||||
schema.Tables = append(schema.Tables[:tableIndex], schema.Tables[tableIndex+1:]...)
|
||||
return true
|
||||
}
|
||||
|
||||
// GetTable returns a table by schema and table index
|
||||
func (se *SchemaEditor) GetTable(schemaIndex, tableIndex int) *models.Table {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return schema.Tables[tableIndex]
|
||||
}
|
||||
|
||||
// GetAllTables returns all tables across all schemas
|
||||
func (se *SchemaEditor) GetAllTables() []*models.Table {
|
||||
var tables []*models.Table
|
||||
for _, schema := range se.db.Schemas {
|
||||
tables = append(tables, schema.Tables...)
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
// GetTablesInSchema returns all tables in a specific schema
|
||||
func (se *SchemaEditor) GetTablesInSchema(schemaIndex int) []*models.Table {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
return se.db.Schemas[schemaIndex].Tables
|
||||
}
|
||||
546
pkg/ui/table_screens.go
Normal file
546
pkg/ui/table_screens.go
Normal file
@@ -0,0 +1,546 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// showTableList displays all tables across all schemas
|
||||
func (se *SchemaEditor) showTableList() {
|
||||
flex := tview.NewFlex().SetDirection(tview.FlexRow)
|
||||
|
||||
// Title
|
||||
title := tview.NewTextView().
|
||||
SetText("[::b]All Tables").
|
||||
SetDynamicColors(true).
|
||||
SetTextAlign(tview.AlignCenter)
|
||||
|
||||
// Create tables table
|
||||
tableTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)
|
||||
|
||||
// Add header row with padding for full width
|
||||
headers := []string{"Name", "Schema", "Sequence", "Total Columns", "Total Relations", "Total Indexes", "GUID", "Description", "Comment"}
|
||||
headerWidths := []int{18, 15, 12, 14, 15, 14, 36, 0, 12} // Description gets remainder
|
||||
for i, header := range headers {
|
||||
padding := ""
|
||||
if i < len(headerWidths) && headerWidths[i] > 0 {
|
||||
padding = strings.Repeat(" ", headerWidths[i]-len(header))
|
||||
}
|
||||
cell := tview.NewTableCell(header + padding).
|
||||
SetTextColor(tcell.ColorYellow).
|
||||
SetSelectable(false).
|
||||
SetAlign(tview.AlignLeft)
|
||||
tableTable.SetCell(0, i, cell)
|
||||
}
|
||||
|
||||
var tables []*models.Table
|
||||
var tableLocations []struct{ schemaIdx, tableIdx int }
|
||||
|
||||
for si, schema := range se.db.Schemas {
|
||||
for ti, table := range schema.Tables {
|
||||
tables = append(tables, table)
|
||||
tableLocations = append(tableLocations, struct{ schemaIdx, tableIdx int }{si, ti})
|
||||
}
|
||||
}
|
||||
|
||||
for row, table := range tables {
|
||||
tableIdx := tableLocations[row]
|
||||
schema := se.db.Schemas[tableIdx.schemaIdx]
|
||||
|
||||
// Name - pad to 18 chars
|
||||
nameStr := fmt.Sprintf("%-18s", table.Name)
|
||||
nameCell := tview.NewTableCell(nameStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 0, nameCell)
|
||||
|
||||
// Schema - pad to 15 chars
|
||||
schemaStr := fmt.Sprintf("%-15s", schema.Name)
|
||||
schemaCell := tview.NewTableCell(schemaStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 1, schemaCell)
|
||||
|
||||
// Sequence - pad to 12 chars
|
||||
seqStr := fmt.Sprintf("%-12s", fmt.Sprintf("%d", table.Sequence))
|
||||
seqCell := tview.NewTableCell(seqStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 2, seqCell)
|
||||
|
||||
// Total Columns - pad to 14 chars
|
||||
colsStr := fmt.Sprintf("%-14s", fmt.Sprintf("%d", len(table.Columns)))
|
||||
colsCell := tview.NewTableCell(colsStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 3, colsCell)
|
||||
|
||||
// Total Relations - pad to 15 chars
|
||||
relsStr := fmt.Sprintf("%-15s", fmt.Sprintf("%d", len(table.Relationships)))
|
||||
relsCell := tview.NewTableCell(relsStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 4, relsCell)
|
||||
|
||||
// Total Indexes - pad to 14 chars
|
||||
idxStr := fmt.Sprintf("%-14s", fmt.Sprintf("%d", len(table.Indexes)))
|
||||
idxCell := tview.NewTableCell(idxStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 5, idxCell)
|
||||
|
||||
// GUID - pad to 36 chars
|
||||
guidStr := fmt.Sprintf("%-36s", table.GUID)
|
||||
guidCell := tview.NewTableCell(guidStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 6, guidCell)
|
||||
|
||||
// Description - no padding, takes remaining space
|
||||
descCell := tview.NewTableCell(table.Description).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 7, descCell)
|
||||
|
||||
// Comment - pad to 12 chars
|
||||
commentStr := fmt.Sprintf("%-12s", table.Comment)
|
||||
commentCell := tview.NewTableCell(commentStr).SetSelectable(true)
|
||||
tableTable.SetCell(row+1, 8, commentCell)
|
||||
}
|
||||
|
||||
tableTable.SetTitle(" All Tables ").SetBorder(true).SetTitleAlign(tview.AlignLeft)
|
||||
|
||||
// Action buttons (define before input capture)
|
||||
btnFlex := tview.NewFlex()
|
||||
btnNewTable := tview.NewButton("New Table [n]").SetSelectedFunc(func() {
|
||||
se.showNewTableDialogFromList()
|
||||
})
|
||||
btnBack := tview.NewButton("Back [b]").SetSelectedFunc(func() {
|
||||
se.pages.SwitchToPage("main")
|
||||
se.pages.RemovePage("tables")
|
||||
})
|
||||
|
||||
// Set up button input captures for Tab/Shift+Tab navigation
|
||||
btnNewTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(tableTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnBack)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnBack.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnNewTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(tableTable)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnFlex.AddItem(btnNewTable, 0, 1, true).
|
||||
AddItem(btnBack, 0, 1, false)
|
||||
|
||||
tableTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.SwitchToPage("main")
|
||||
se.pages.RemovePage("tables")
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnNewTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyEnter {
|
||||
row, _ := tableTable.GetSelection()
|
||||
if row > 0 && row <= len(tables) { // Skip header row
|
||||
tableIdx := tableLocations[row-1]
|
||||
se.showTableEditor(tableIdx.schemaIdx, tableIdx.tableIdx, tables[row-1])
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if event.Rune() == 'n' {
|
||||
se.showNewTableDialogFromList()
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'b' {
|
||||
se.pages.SwitchToPage("main")
|
||||
se.pages.RemovePage("tables")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
flex.AddItem(title, 1, 0, false).
|
||||
AddItem(tableTable, 0, 1, true).
|
||||
AddItem(btnFlex, 1, 0, false)
|
||||
|
||||
se.pages.AddPage("tables", flex, true, true)
|
||||
}
|
||||
|
||||
// showTableEditor shows editor for a specific table
|
||||
func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *models.Table) {
|
||||
flex := tview.NewFlex().SetDirection(tview.FlexRow)
|
||||
|
||||
// Title
|
||||
title := tview.NewTextView().
|
||||
SetText(fmt.Sprintf("[::b]Table: %s", table.Name)).
|
||||
SetDynamicColors(true).
|
||||
SetTextAlign(tview.AlignCenter)
|
||||
|
||||
// Table info
|
||||
info := tview.NewTextView().SetDynamicColors(true)
|
||||
info.SetText(fmt.Sprintf("Schema: %s | Columns: %d | Description: %s",
|
||||
table.Schema, len(table.Columns), table.Description))
|
||||
|
||||
// Create columns table
|
||||
colTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)
|
||||
|
||||
// Add header row with padding for full width
|
||||
headers := []string{"Name", "Type", "Default", "KeyType", "GUID", "Description"}
|
||||
headerWidths := []int{20, 18, 15, 15, 36} // Last column takes remaining space
|
||||
for i, header := range headers {
|
||||
padding := ""
|
||||
if i < len(headerWidths) {
|
||||
padding = strings.Repeat(" ", headerWidths[i]-len(header))
|
||||
}
|
||||
cell := tview.NewTableCell(header + padding).
|
||||
SetTextColor(tcell.ColorYellow).
|
||||
SetSelectable(false).
|
||||
SetAlign(tview.AlignLeft)
|
||||
colTable.SetCell(0, i, cell)
|
||||
}
|
||||
|
||||
// Get sorted column names
|
||||
columnNames := getColumnNames(table)
|
||||
for row, colName := range columnNames {
|
||||
column := table.Columns[colName]
|
||||
|
||||
// Name - pad to 20 chars
|
||||
nameStr := fmt.Sprintf("%-20s", colName)
|
||||
nameCell := tview.NewTableCell(nameStr).SetSelectable(true)
|
||||
colTable.SetCell(row+1, 0, nameCell)
|
||||
|
||||
// Type - pad to 18 chars
|
||||
typeStr := fmt.Sprintf("%-18s", column.Type)
|
||||
typeCell := tview.NewTableCell(typeStr).SetSelectable(true)
|
||||
colTable.SetCell(row+1, 1, typeCell)
|
||||
|
||||
// Default - pad to 15 chars
|
||||
defaultStr := ""
|
||||
if column.Default != nil {
|
||||
defaultStr = fmt.Sprintf("%v", column.Default)
|
||||
}
|
||||
defaultStr = fmt.Sprintf("%-15s", defaultStr)
|
||||
defaultCell := tview.NewTableCell(defaultStr).SetSelectable(true)
|
||||
colTable.SetCell(row+1, 2, defaultCell)
|
||||
|
||||
// KeyType - pad to 15 chars
|
||||
keyTypeStr := ""
|
||||
if column.IsPrimaryKey {
|
||||
keyTypeStr = "PRIMARY"
|
||||
} else if column.NotNull {
|
||||
keyTypeStr = "NOT NULL"
|
||||
}
|
||||
keyTypeStr = fmt.Sprintf("%-15s", keyTypeStr)
|
||||
keyTypeCell := tview.NewTableCell(keyTypeStr).SetSelectable(true)
|
||||
colTable.SetCell(row+1, 3, keyTypeCell)
|
||||
|
||||
// GUID - pad to 36 chars
|
||||
guidStr := fmt.Sprintf("%-36s", column.GUID)
|
||||
guidCell := tview.NewTableCell(guidStr).SetSelectable(true)
|
||||
colTable.SetCell(row+1, 4, guidCell)
|
||||
|
||||
// Description
|
||||
descCell := tview.NewTableCell(column.Description).SetSelectable(true)
|
||||
colTable.SetCell(row+1, 5, descCell)
|
||||
}
|
||||
|
||||
colTable.SetTitle(" Columns ").SetBorder(true).SetTitleAlign(tview.AlignLeft)
|
||||
|
||||
// Action buttons flex (define before input capture)
|
||||
btnFlex := tview.NewFlex()
|
||||
btnNewCol := tview.NewButton("Add Column [n]").SetSelectedFunc(func() {
|
||||
se.showNewColumnDialog(schemaIndex, tableIndex)
|
||||
})
|
||||
btnEditTable := tview.NewButton("Edit Table [e]").SetSelectedFunc(func() {
|
||||
se.showEditTableDialog(schemaIndex, tableIndex)
|
||||
})
|
||||
btnEditColumn := tview.NewButton("Edit Column [c]").SetSelectedFunc(func() {
|
||||
row, _ := colTable.GetSelection()
|
||||
if row > 0 && row <= len(columnNames) { // Skip header row
|
||||
colName := columnNames[row-1]
|
||||
column := table.Columns[colName]
|
||||
se.showColumnEditor(schemaIndex, tableIndex, row-1, column)
|
||||
}
|
||||
})
|
||||
btnDelTable := tview.NewButton("Delete Table [d]").SetSelectedFunc(func() {
|
||||
se.showDeleteTableConfirm(schemaIndex, tableIndex)
|
||||
})
|
||||
btnBack := tview.NewButton("Back to Schema [b]").SetSelectedFunc(func() {
|
||||
se.pages.RemovePage("table-editor")
|
||||
se.pages.SwitchToPage("schema-editor")
|
||||
})
|
||||
|
||||
// Set up button input captures for Tab/Shift+Tab navigation
|
||||
btnNewCol.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(colTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnEditColumn)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnEditColumn.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnNewCol)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnEditTable)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnEditTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnEditColumn)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnDelTable)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnDelTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnEditTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnBack)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnBack.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnDelTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(colTable)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnFlex.AddItem(btnNewCol, 0, 1, true).
|
||||
AddItem(btnEditColumn, 0, 1, false).
|
||||
AddItem(btnEditTable, 0, 1, false).
|
||||
AddItem(btnDelTable, 0, 1, false).
|
||||
AddItem(btnBack, 0, 1, false)
|
||||
|
||||
colTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.SwitchToPage("schema-editor")
|
||||
se.pages.RemovePage("table-editor")
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnNewCol)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyEnter {
|
||||
row, _ := colTable.GetSelection()
|
||||
if row > 0 { // Skip header row
|
||||
colName := columnNames[row-1]
|
||||
column := table.Columns[colName]
|
||||
se.showColumnEditor(schemaIndex, tableIndex, row-1, column)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if event.Rune() == 'c' {
|
||||
row, _ := colTable.GetSelection()
|
||||
if row > 0 && row <= len(columnNames) { // Skip header row
|
||||
colName := columnNames[row-1]
|
||||
column := table.Columns[colName]
|
||||
se.showColumnEditor(schemaIndex, tableIndex, row-1, column)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'b' {
|
||||
se.pages.RemovePage("table-editor")
|
||||
se.pages.SwitchToPage("schema-editor")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
flex.AddItem(title, 1, 0, false).
|
||||
AddItem(info, 2, 0, false).
|
||||
AddItem(colTable, 0, 1, true).
|
||||
AddItem(btnFlex, 1, 0, false)
|
||||
|
||||
se.pages.AddPage("table-editor", flex, true, true)
|
||||
}
|
||||
|
||||
// showNewTableDialog shows dialog to create a new table
|
||||
func (se *SchemaEditor) showNewTableDialog(schemaIndex int) {
|
||||
form := tview.NewForm()
|
||||
|
||||
tableName := ""
|
||||
description := ""
|
||||
|
||||
form.AddInputField("Table Name", "", 40, nil, func(value string) {
|
||||
tableName = value
|
||||
})
|
||||
|
||||
form.AddInputField("Description", "", 40, nil, func(value string) {
|
||||
description = value
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
if tableName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
se.CreateTable(schemaIndex, tableName, description)
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
se.pages.RemovePage("new-table")
|
||||
se.pages.RemovePage("schema-editor")
|
||||
se.showSchemaEditor(schemaIndex, schema)
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
se.pages.RemovePage("new-table")
|
||||
se.pages.RemovePage("schema-editor")
|
||||
se.showSchemaEditor(schemaIndex, schema)
|
||||
})
|
||||
|
||||
form.SetBorder(true).SetTitle(" New Table ").SetTitleAlign(tview.AlignLeft)
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.showExitConfirmation("new-table", "schema-editor")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("new-table", form, true, true)
|
||||
}
|
||||
|
||||
// showNewTableDialogFromList shows dialog to create a new table with schema selection
|
||||
func (se *SchemaEditor) showNewTableDialogFromList() {
|
||||
form := tview.NewForm()
|
||||
|
||||
tableName := ""
|
||||
description := ""
|
||||
selectedSchemaIdx := 0
|
||||
|
||||
// Create schema dropdown options
|
||||
schemaOptions := make([]string, len(se.db.Schemas))
|
||||
for i, schema := range se.db.Schemas {
|
||||
schemaOptions[i] = schema.Name
|
||||
}
|
||||
|
||||
form.AddInputField("Table Name", "", 40, nil, func(value string) {
|
||||
tableName = value
|
||||
})
|
||||
|
||||
form.AddDropDown("Schema", schemaOptions, 0, func(option string, optionIndex int) {
|
||||
selectedSchemaIdx = optionIndex
|
||||
})
|
||||
|
||||
form.AddInputField("Description", "", 40, nil, func(value string) {
|
||||
description = value
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
if tableName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
se.CreateTable(selectedSchemaIdx, tableName, description)
|
||||
|
||||
se.pages.RemovePage("new-table-from-list")
|
||||
se.pages.RemovePage("tables")
|
||||
se.showTableList()
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
se.pages.RemovePage("new-table-from-list")
|
||||
se.pages.RemovePage("tables")
|
||||
se.showTableList()
|
||||
})
|
||||
|
||||
form.SetBorder(true).SetTitle(" New Table ").SetTitleAlign(tview.AlignLeft)
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.showExitConfirmation("new-table-from-list", "tables")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("new-table-from-list", form, true, true)
|
||||
}
|
||||
|
||||
// showEditTableDialog shows dialog to edit table properties
|
||||
func (se *SchemaEditor) showEditTableDialog(schemaIndex, tableIndex int) {
|
||||
table := se.db.Schemas[schemaIndex].Tables[tableIndex]
|
||||
form := tview.NewForm()
|
||||
|
||||
// Local variables to collect changes
|
||||
newName := table.Name
|
||||
newDescription := table.Description
|
||||
newGUID := table.GUID
|
||||
|
||||
form.AddInputField("Table Name", table.Name, 40, nil, func(value string) {
|
||||
newName = value
|
||||
})
|
||||
|
||||
form.AddTextArea("Description", table.Description, 40, 5, 0, func(value string) {
|
||||
newDescription = value
|
||||
})
|
||||
|
||||
form.AddInputField("GUID", table.GUID, 40, nil, func(value string) {
|
||||
newGUID = value
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
// Apply changes using dataops
|
||||
se.UpdateTable(schemaIndex, tableIndex, newName, newDescription)
|
||||
se.db.Schemas[schemaIndex].Tables[tableIndex].GUID = newGUID
|
||||
|
||||
table := se.db.Schemas[schemaIndex].Tables[tableIndex]
|
||||
se.pages.RemovePage("edit-table")
|
||||
se.pages.RemovePage("table-editor")
|
||||
se.showTableEditor(schemaIndex, tableIndex, table)
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
// Discard changes - don't apply them
|
||||
table := se.db.Schemas[schemaIndex].Tables[tableIndex]
|
||||
se.pages.RemovePage("edit-table")
|
||||
se.pages.RemovePage("table-editor")
|
||||
se.showTableEditor(schemaIndex, tableIndex, table)
|
||||
})
|
||||
|
||||
form.SetBorder(true).SetTitle(" Edit Table ").SetTitleAlign(tview.AlignLeft)
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.showExitConfirmation("edit-table", "table-editor")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("edit-table", form, true, true)
|
||||
}
|
||||
299
pkg/ui/ui_rules.md
Normal file
299
pkg/ui/ui_rules.md
Normal file
@@ -0,0 +1,299 @@
|
||||
# UI Rules and Guidelines
|
||||
|
||||
## Layout Requirements
|
||||
|
||||
All layouts and forms must live in separate files, organized by their domain or entity.
|
||||
|
||||
### Screen Layout Structure
|
||||
|
||||
All screens must follow this consistent layout:
|
||||
|
||||
1. **Title at the Top** (1 row, fixed height)
|
||||
- Centered bold text: `[::b]Title Text`
|
||||
- Use `tview.NewTextView()` with `SetTextAlign(tview.AlignCenter)`
|
||||
- Enable dynamic colors: `SetDynamicColors(true)`
|
||||
|
||||
2. **Content in the Middle** (flexible height)
|
||||
- Tables, lists, forms, or info displays
|
||||
- Uses flex weight of 1 for dynamic sizing
|
||||
|
||||
3. **Action Buttons at the Bottom** (1 row, fixed height)
|
||||
- Must be in a horizontal flex container
|
||||
- Action buttons before Back button
|
||||
- Back button is always last
|
||||
|
||||
### Form Layout Structure
|
||||
|
||||
All forms must follow this button order:
|
||||
|
||||
1. **Save Button** (always first)
|
||||
- Label: "Save"
|
||||
- Primary action that commits changes
|
||||
|
||||
2. **Delete Button** (optional, only for edit forms)
|
||||
- Label: "Delete"
|
||||
- Only shown when editing existing items (not for new items)
|
||||
- Must show confirmation dialog before deletion
|
||||
|
||||
3. **Back Button** (always last)
|
||||
- Label: "Back"
|
||||
- Returns to previous screen without saving
|
||||
|
||||
**Button Order Examples:**
|
||||
- **New Item Forms:** Save, Back
|
||||
- **Edit Item Forms:** Save, Delete, Back
|
||||
|
||||
## Tab Navigation
|
||||
|
||||
All screens must implement circular tab navigation:
|
||||
|
||||
1. **Tab Key** - Moves focus to the next focusable element
|
||||
2. **Shift+Tab (BackTab)** - Moves focus to the previous focusable element
|
||||
3. **At the End** - Tab cycles back to the first element
|
||||
4. **At the Start** - Shift+Tab cycles back to the last element
|
||||
|
||||
**Navigation Flow Pattern:**
|
||||
- Each widget must handle both Tab and BackTab
|
||||
- First widget: BackTab → Last widget, Tab → Second widget
|
||||
- Middle widgets: BackTab → Previous widget, Tab → Next widget
|
||||
- Last widget: BackTab → Previous widget, Tab → First widget
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
### Standard Keys
|
||||
|
||||
- **ESC** - Cancel current operation or go back to previous screen
|
||||
- **Tab** - Move focus forward (circular)
|
||||
- **Shift+Tab** - Move focus backward (circular)
|
||||
- **Enter** - Activate/select current item in tables and lists
|
||||
|
||||
### Letter Key Shortcuts
|
||||
|
||||
- **'n'** - New (create new item)
|
||||
- **'b'** - Back (return to previous screen)
|
||||
- **'e'** - Edit (edit current item)
|
||||
- **'d'** - Delete (delete current item)
|
||||
- **'c'** - Edit Column (in table editor)
|
||||
- **'s'** - Manage Schemas (in main menu)
|
||||
- **'t'** - Manage Tables (in main menu)
|
||||
- **'q'** - Quit/Exit (in main menu)
|
||||
|
||||
## Consistency Requirements
|
||||
|
||||
1. **Layout Structure** - All screens: Title (top) → Content (middle) → Buttons (bottom)
|
||||
2. **Title Format** - Bold (`[::b]`), centered, dynamic colors enabled
|
||||
3. **Tables** - Fixed headers (row 0), borders enabled, selectable rows
|
||||
4. **Buttons** - Include keyboard shortcuts in labels (e.g., "Back [b]")
|
||||
5. **Forms** - Button order: Save, Delete (if edit), Back
|
||||
6. **Destructive Actions** - Always show confirmation dialogs
|
||||
7. **ESC Key** - All screens support ESC to go back
|
||||
8. **Action Buttons** - Positioned before Back button, in logical order
|
||||
9. **Data Refresh** - Always refresh the previous screen when returning from a form or dialog
|
||||
|
||||
## Widget Naming Conventions
|
||||
|
||||
- **Tables:** `schemaTable`, `tableTable`, `colTable`
|
||||
- **Buttons:** Prefix with `btn` (e.g., `btnBack`, `btnDelete`, `btnNewSchema`)
|
||||
- **Flex containers:** `btnFlex` for button containers, `flex` for main layout
|
||||
- **Forms:** `form`
|
||||
- **Lists:** `list`, `tableList`
|
||||
- **Text views:** `title`, `info`
|
||||
- Use camelCase for all variable names
|
||||
|
||||
## Page Naming Conventions
|
||||
|
||||
Use descriptive kebab-case names:
|
||||
|
||||
- **Main screens:** `main`, `schemas`, `tables`, `schema-editor`, `table-editor`, `column-editor`
|
||||
- **Load/Save screens:** `load-database`, `save-database`
|
||||
- **Creation dialogs:** `new-schema`, `new-table`, `new-column`, `new-table-from-list`
|
||||
- **Edit dialogs:** `edit-schema`, `edit-table`
|
||||
- **Confirmations:** `confirm-delete-schema`, `confirm-delete-table`, `confirm-delete-column`
|
||||
- **Exit confirmations:** `exit-confirm`, `exit-editor-confirm`
|
||||
- **Status dialogs:** `error-dialog`, `success-dialog`
|
||||
|
||||
## Dialog and Confirmation Rules
|
||||
|
||||
### Confirmation Dialogs
|
||||
|
||||
1. **Delete Confirmations** - Required for all destructive actions
|
||||
- Show item name in confirmation text
|
||||
- Buttons: "Cancel", "Delete"
|
||||
- ESC key dismisses dialog
|
||||
|
||||
2. **Exit Confirmations** - Required when exiting forms with potential unsaved changes
|
||||
- Text: "Exit without saving changes?"
|
||||
- Buttons: "Cancel", "No, exit without saving"
|
||||
- ESC key confirms exit
|
||||
|
||||
3. **Save Confirmations** - Optional, based on context
|
||||
- Use for critical data changes
|
||||
- Clear description of what will be saved
|
||||
|
||||
### Dialog Behavior
|
||||
|
||||
- All dialogs must capture ESC key for dismissal
|
||||
- Modal dialogs overlay current screen
|
||||
- Confirmation dialogs use `tview.NewModal()`
|
||||
- Remove dialog page after action completes
|
||||
|
||||
## Data Refresh Rules
|
||||
|
||||
When returning from any form or dialog, the previous screen must be refreshed to show updated data. If tables exist on the screen, their data must be updated:
|
||||
|
||||
1. **After Save** - Remove and recreate the previous screen to display updated data
|
||||
2. **After Delete** - Remove and recreate the previous screen to display remaining data
|
||||
3. **After Cancel/Back** - Remove and recreate the previous screen (data may have changed)
|
||||
4. **Implementation Pattern** - Remove the current page, remove the previous page, then recreate the previous page with fresh data
|
||||
|
||||
**Why This Matters:**
|
||||
- Ensures users see their changes immediately
|
||||
- Prevents stale data from being displayed
|
||||
- Maintains data consistency across the UI
|
||||
- Avoids confusion from seeing outdated information
|
||||
|
||||
**Example Flow:**
|
||||
```
|
||||
User on Schema List → Opens Edit Schema Form → Saves Changes →
|
||||
Returns to Schema List (refreshed with updated schema data)
|
||||
```
|
||||
|
||||
## Big Loading/Saving Operations
|
||||
|
||||
When loading big changes, files or data, always give a load completed or load error dialog.
|
||||
Do the same with saving.
|
||||
This informs the user what happens.
|
||||
When data is dirty, always ask the user to save when trying to exit.
|
||||
|
||||
### Load/Save Screens
|
||||
|
||||
- **Load Screen** (`load-database`) - Shown when no source is specified via command line
|
||||
- Format dropdown (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)
|
||||
- File path input (for file-based formats)
|
||||
- Connection string input (for database formats like pgsql)
|
||||
- Load button [l] - Loads the database
|
||||
- Create New button [n] - Creates a new empty database
|
||||
- Exit button [q] - Exits the application
|
||||
- ESC key exits the application
|
||||
|
||||
- **Save Screen** (`save-database`) - Accessible from main menu with 'w' key
|
||||
- Format dropdown (same as load screen)
|
||||
- File path input
|
||||
- Help text explaining format requirements
|
||||
- Save button [s] - Saves the database
|
||||
- Back button [b] - Returns to main menu
|
||||
- ESC key returns to main menu
|
||||
- Pre-populated with existing save configuration if available
|
||||
|
||||
### Status Dialogs
|
||||
|
||||
- **Error Dialog** (`error-dialog`) - Shows error messages with OK button and ESC to dismiss
|
||||
- **Success Dialog** (`success-dialog`) - Shows success messages with OK button, ESC to dismiss, and optional callback on close
|
||||
|
||||
## Screen Organization
|
||||
|
||||
Organize UI code into these files:
|
||||
|
||||
### UI Files (Screens and Dialogs)
|
||||
|
||||
- **editor.go** - Core `SchemaEditor` struct, constructor, `Run()` method, helper functions
|
||||
- **main_menu.go** - Main menu screen
|
||||
- **load_save_screens.go** - Database load and save screens
|
||||
- **database_screens.go** - Database edit form
|
||||
- **schema_screens.go** - Schema list, schema editor, new/edit schema dialogs
|
||||
- **table_screens.go** - Tables list, table editor, new/edit table dialogs
|
||||
- **column_screens.go** - Column editor, new column dialog
|
||||
- **domain_screens.go** - Domain list, domain editor, new/edit domain dialogs
|
||||
- **dialogs.go** - Confirmation dialogs (exit, delete)
|
||||
|
||||
### Data Operations Files (Business Logic)
|
||||
|
||||
- **schema_dataops.go** - Schema CRUD operations (Create, Read, Update, Delete)
|
||||
- **table_dataops.go** - Table CRUD operations
|
||||
- **column_dataops.go** - Column CRUD operations
|
||||
|
||||
## Code Separation Rules
|
||||
|
||||
### UI vs Business Logic
|
||||
|
||||
1. **UI Files** - Handle only presentation and user interaction
|
||||
- Display data in tables, lists, and forms
|
||||
- Capture user input
|
||||
- Navigate between screens
|
||||
- Show/hide dialogs
|
||||
- Call dataops methods for actual data changes
|
||||
|
||||
2. **Dataops Files** - Handle only business logic and data manipulation
|
||||
- Create, read, update, delete operations
|
||||
- Data validation
|
||||
- Data structure manipulation
|
||||
- Return created/updated objects or success/failure status
|
||||
- No UI code or tview references
|
||||
|
||||
### Implementation Pattern
|
||||
|
||||
#### Creating New Items
|
||||
|
||||
**Bad (Direct Data Manipulation in UI):**
|
||||
```go
|
||||
form.AddButton("Save", func() {
|
||||
schema := &models.Schema{Name: name, Description: desc, ...}
|
||||
se.db.Schemas = append(se.db.Schemas, schema)
|
||||
})
|
||||
```
|
||||
|
||||
**Good (Using Dataops Methods):**
|
||||
```go
|
||||
form.AddButton("Save", func() {
|
||||
se.CreateSchema(name, description)
|
||||
})
|
||||
```
|
||||
|
||||
#### Editing Existing Items
|
||||
|
||||
**Bad (Modifying Data in onChange Callbacks):**
|
||||
```go
|
||||
form.AddInputField("Name", column.Name, 40, nil, func(value string) {
|
||||
column.Name = value // Changes immediately as user types!
|
||||
})
|
||||
form.AddButton("Save", func() {
|
||||
// Data already changed, just refresh screen
|
||||
})
|
||||
```
|
||||
|
||||
**Good (Local Variables + Dataops on Save):**
|
||||
```go
|
||||
// Store original values
|
||||
originalName := column.Name
|
||||
newName := column.Name
|
||||
|
||||
form.AddInputField("Name", column.Name, 40, nil, func(value string) {
|
||||
newName = value // Store in local variable
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
// Apply changes only when Save is clicked
|
||||
se.UpdateColumn(schemaIndex, tableIndex, originalName, newName, ...)
|
||||
// Then refresh screen
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
// Discard changes - don't apply local variables
|
||||
// Just refresh screen
|
||||
})
|
||||
```
|
||||
|
||||
### Why This Matters
|
||||
|
||||
**Edit Forms Must Use Local Variables:**
|
||||
1. **Deferred Changes** - Changes only apply when Save is clicked
|
||||
2. **Cancellable** - Back button discards changes without saving
|
||||
3. **Handles Renames** - The original name is preserved so map keys can be updated correctly when an item is renamed
|
||||
4. **User Expectations** - Save means "commit changes", Back means "cancel"
|
||||
|
||||
This separation ensures:
|
||||
- Cleaner, more maintainable code
|
||||
- Reusable business logic
|
||||
- Easier testing
|
||||
- Clear separation of concerns
|
||||
- Proper change management (save vs cancel)
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// TemplateData represents the data passed to the template for code generation
|
||||
@@ -105,19 +106,20 @@ func (td *TemplateData) FinalizeImports() {
|
||||
}
|
||||
|
||||
// NewModelData creates a new ModelData from a models.Table
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
||||
tableName := table.Name
|
||||
if schema != "" {
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
|
||||
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
|
||||
|
||||
// Generate model name: singularize and convert to PascalCase
|
||||
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||
singularTable := Singularize(table.Name)
|
||||
modelName := SnakeCaseToPascalCase(singularTable)
|
||||
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||
|
||||
// Add "Model" prefix if not already present
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schema != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schema)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
model := &ModelData{
|
||||
@@ -133,8 +135,10 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
// Find primary key
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
||||
model.IDColumnName = col.Name
|
||||
// Sanitize column name to remove backticks
|
||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||
model.IDColumnName = safeName
|
||||
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
||||
@@ -146,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
columns := sortColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
field := columnToField(col, table, typeMapper)
|
||||
// Check for name collision with generated methods and rename if needed
|
||||
field.Name = resolveFieldNameCollision(field.Name)
|
||||
model.Fields = append(model.Fields, field)
|
||||
}
|
||||
|
||||
@@ -154,10 +160,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
|
||||
// columnToField converts a models.Column to FieldData
|
||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
||||
// Sanitize column name first to remove backticks before generating field name
|
||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||
fieldName := SnakeCaseToPascalCase(safeName)
|
||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
bunTag := typeMapper.BuildBunTag(col, table)
|
||||
jsonTag := col.Name // Use column name for JSON tag
|
||||
// Use same sanitized name for JSON tag
|
||||
jsonTag := safeName
|
||||
|
||||
return &FieldData{
|
||||
Name: fieldName,
|
||||
@@ -184,9 +193,28 @@ func formatComment(description, comment string) string {
|
||||
return comment
|
||||
}
|
||||
|
||||
// hasModelPrefix checks if a name already has "Model" prefix
|
||||
func hasModelPrefix(name string) bool {
|
||||
return len(name) >= 5 && name[:5] == "Model"
|
||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||
// and adds an underscore suffix if there's a collision
|
||||
func resolveFieldNameCollision(fieldName string) string {
|
||||
// List of method names that are generated by the template
|
||||
reservedNames := map[string]bool{
|
||||
"TableName": true,
|
||||
"TableNameOnly": true,
|
||||
"SchemaName": true,
|
||||
"GetID": true,
|
||||
"GetIDStr": true,
|
||||
"SetID": true,
|
||||
"UpdateID": true,
|
||||
"GetIDName": true,
|
||||
"GetPrefix": true,
|
||||
}
|
||||
|
||||
// Check if field name conflicts with a reserved method name
|
||||
if reservedNames[fieldName] {
|
||||
return fieldName + "_"
|
||||
}
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// sortColumns sorts columns by sequence, then by name
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// TypeMapper handles type conversions between SQL and Go types for Bun
|
||||
@@ -164,11 +165,14 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
var parts []string
|
||||
|
||||
// Column name comes first (no prefix)
|
||||
parts = append(parts, column.Name)
|
||||
// Sanitize to remove backticks which would break struct tag syntax
|
||||
safeName := writers.SanitizeStructTagValue(column.Name)
|
||||
parts = append(parts, safeName)
|
||||
|
||||
// Add type if specified
|
||||
if column.Type != "" {
|
||||
typeStr := column.Type
|
||||
// Sanitize type to remove backticks
|
||||
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||
if column.Length > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||
} else if column.Precision > 0 {
|
||||
@@ -188,12 +192,17 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
|
||||
// Default value
|
||||
if column.Default != nil {
|
||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
||||
// Sanitize default value to remove backticks
|
||||
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
||||
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||
}
|
||||
|
||||
// Nullable (Bun uses nullzero for nullable fields)
|
||||
// and notnull tag for explicitly non-nullable fields
|
||||
if !column.NotNull && !column.IsPrimaryKey {
|
||||
parts = append(parts, "nullzero")
|
||||
} else if column.NotNull && !column.IsPrimaryKey {
|
||||
parts = append(parts, "notnull")
|
||||
}
|
||||
|
||||
// Check for indexes (unique indexes should be added to tag)
|
||||
@@ -260,7 +269,7 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
||||
|
||||
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
|
||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||
return "github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||
}
|
||||
|
||||
// GetBunImport returns the import path for Bun
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -85,7 +86,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
// Collect all models
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(formatted)
|
||||
if err := w.writeOutput(formatted); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run go fmt on the output file
|
||||
if w.options.OutputPath != "" {
|
||||
w.runGoFmt(w.options.OutputPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
@@ -171,7 +181,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Create model data
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -207,13 +217,19 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
}
|
||||
|
||||
// Generate filename: sql_{schema}_{table}.go
|
||||
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
||||
// Sanitize schema and table names to remove quotes, comments, and invalid characters
|
||||
safeSchemaName := writers.SanitizeFilename(schema.Name)
|
||||
safeTableName := writers.SanitizeFilename(table.Name)
|
||||
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
|
||||
filepath := filepath.Join(w.options.OutputPath, filename)
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Run go fmt on the generated file
|
||||
w.runGoFmt(filepath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +238,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
|
||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||
// Track used field names to detect duplicates
|
||||
usedFieldNames := make(map[string]int)
|
||||
|
||||
// For each foreign key in this table, add a belongs-to/has-one relationship
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
@@ -235,8 +254,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
}
|
||||
|
||||
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
||||
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||
fieldName := w.generateHasOneFieldName(constraint)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
@@ -263,8 +283,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
// Check if this constraint references our table
|
||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||
// Add has-many relationship
|
||||
otherModelName := w.getModelName(otherTable.Name)
|
||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
||||
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
@@ -295,22 +316,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
||||
return nil
|
||||
}
|
||||
|
||||
// getModelName generates the model name from a table name
|
||||
func (w *Writer) getModelName(tableName string) string {
|
||||
// getModelName generates the model name from schema and table name
|
||||
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
modelName := SnakeCaseToPascalCase(singular)
|
||||
tablePart := SnakeCaseToPascalCase(singular)
|
||||
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schemaName != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
// generateRelationshipFieldName generates a field name for a relationship
|
||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
||||
// Use just the prefix (3 letters) for relationship fields
|
||||
return GeneratePrefix(tableName)
|
||||
// generateHasOneFieldName generates a field name for has-one relationships
|
||||
// Uses the foreign key column name for uniqueness
|
||||
func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
||||
// Use the foreign key column name to ensure uniqueness
|
||||
// If there are multiple columns, use the first one
|
||||
if len(constraint.Columns) > 0 {
|
||||
columnName := constraint.Columns[0]
|
||||
// Convert to PascalCase for proper Go field naming
|
||||
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
||||
return "Rel" + SnakeCaseToPascalCase(columnName)
|
||||
}
|
||||
|
||||
// Fallback to table-based prefix if no columns defined
|
||||
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
||||
}
|
||||
|
||||
// generateHasManyFieldName generates a field name for has-many relationships
|
||||
// Uses the foreign key column name + source table name to avoid duplicates
|
||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||
// For has-many, we need to include the source table name to avoid duplicates
|
||||
// e.g., multiple tables referencing the same column on this table
|
||||
if len(constraint.Columns) > 0 {
|
||||
columnName := constraint.Columns[0]
|
||||
// Get the model name for the source table (pluralized)
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
// Remove "Model" prefix if present
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
|
||||
// Convert column to PascalCase and combine with source table
|
||||
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
||||
columnPart := SnakeCaseToPascalCase(columnName)
|
||||
return "Rel" + columnPart + Pluralize(sourceModelName)
|
||||
}
|
||||
|
||||
// Fallback to table-based naming
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
return "Rel" + Pluralize(sourceModelName)
|
||||
}
|
||||
|
||||
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
||||
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
||||
originalName := fieldName
|
||||
count := usedNames[originalName]
|
||||
|
||||
if count > 0 {
|
||||
// Name is already used, add numeric suffix
|
||||
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
||||
}
|
||||
|
||||
// Increment the counter for this base name
|
||||
usedNames[originalName]++
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// getPackageName returns the package name from options or defaults to "models"
|
||||
@@ -341,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runGoFmt runs go fmt on the specified file
|
||||
func (w *Writer) runGoFmt(filepath string) {
|
||||
cmd := exec.Command("gofmt", "-w", filepath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Don't fail the whole operation if gofmt fails, just warn
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||
func (w *Writer) shouldUseMultiFile() bool {
|
||||
// Check if multi_file is explicitly set in metadata
|
||||
@@ -386,6 +471,7 @@ func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
|
||||
DatabaseVersion: db.DatabaseVersion,
|
||||
SourceFormat: db.SourceFormat,
|
||||
Schemas: nil, // Don't include schemas to avoid circular reference
|
||||
GUID: db.GUID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,5 +488,6 @@ func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *mo
|
||||
Sequence: schema.Sequence,
|
||||
RefDatabase: w.createDatabaseRef(db), // Include database ref
|
||||
Tables: nil, // Don't include tables to avoid circular reference
|
||||
GUID: schema.GUID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelUser struct",
|
||||
"type ModelPublicUser struct",
|
||||
"bun.BaseModel",
|
||||
"table:public.users",
|
||||
"alias:users",
|
||||
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
"resolvespec_common.SqlTime",
|
||||
"bun:\"id",
|
||||
"bun:\"email",
|
||||
"func (m ModelUser) TableName() string",
|
||||
"func (m ModelPublicUser) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelUser) GetID() int64",
|
||||
"func (m ModelPublicUser) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
@@ -175,12 +175,378 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
||||
postsStr := string(postsContent)
|
||||
|
||||
// Verify relationship is present with Bun format
|
||||
if !strings.Contains(postsStr, "USE") {
|
||||
t.Errorf("Missing relationship field USE")
|
||||
// Should now be RelUserID (has-one) instead of USE
|
||||
if !strings.Contains(postsStr, "RelUserID") {
|
||||
t.Errorf("Missing relationship field RelUserID (new naming convention)")
|
||||
}
|
||||
if !strings.Contains(postsStr, "rel:has-one") {
|
||||
t.Errorf("Missing Bun relationship tag: %s", postsStr)
|
||||
}
|
||||
|
||||
// Check users file contains has-many relationship
|
||||
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read users file: %v", err)
|
||||
}
|
||||
|
||||
usersStr := string(usersContent)
|
||||
|
||||
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||
// Test scenario: api_event table with multiple foreign keys to filepointer table
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("org")
|
||||
|
||||
// Filepointer table
|
||||
filepointer := models.InitTable("filepointer", "org")
|
||||
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||
Name: "id_filepointer",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
schema.Tables = append(schema.Tables, filepointer)
|
||||
|
||||
// API event table with two foreign keys to filepointer
|
||||
apiEvent := models.InitTable("api_event", "org")
|
||||
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||
Name: "id_api_event",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
|
||||
Name: "rid_filepointer_request",
|
||||
Type: "bigint",
|
||||
NotNull: false,
|
||||
}
|
||||
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
|
||||
Name: "rid_filepointer_response",
|
||||
Type: "bigint",
|
||||
NotNull: false,
|
||||
}
|
||||
|
||||
// Add constraints
|
||||
apiEvent.Constraints["fk_request"] = &models.Constraint{
|
||||
Name: "fk_request",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_filepointer_request"},
|
||||
ReferencedTable: "filepointer",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_filepointer"},
|
||||
}
|
||||
apiEvent.Constraints["fk_response"] = &models.Constraint{
|
||||
Name: "fk_response",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_filepointer_response"},
|
||||
ReferencedTable: "filepointer",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_filepointer"},
|
||||
}
|
||||
|
||||
schema.Tables = append(schema.Tables, apiEvent)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: tmpDir,
|
||||
Metadata: map[string]interface{}{
|
||||
"multi_file": true,
|
||||
},
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the api_event file
|
||||
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read api_event file: %v", err)
|
||||
}
|
||||
|
||||
contentStr := string(apiEventContent)
|
||||
|
||||
// Verify both relationships have unique names based on column names
|
||||
expectations := []struct {
|
||||
fieldName string
|
||||
tag string
|
||||
}{
|
||||
{"RelRIDFilepointerRequest", "join:rid_filepointer_request=id_filepointer"},
|
||||
{"RelRIDFilepointerResponse", "join:rid_filepointer_response=id_filepointer"},
|
||||
}
|
||||
|
||||
for _, exp := range expectations {
|
||||
if !strings.Contains(contentStr, exp.fieldName) {
|
||||
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
|
||||
}
|
||||
if !strings.Contains(contentStr, exp.tag) {
|
||||
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
|
||||
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
|
||||
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
|
||||
}
|
||||
|
||||
// Also verify has-many relationships on filepointer table
|
||||
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read filepointer file: %v", err)
|
||||
}
|
||||
|
||||
filepointerStr := string(filepointerContent)
|
||||
|
||||
// Should have two different has-many relationships with unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
if !strings.Contains(filepointerStr, exp) {
|
||||
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("org")
|
||||
|
||||
// Owner table
|
||||
owner := models.InitTable("owner", "org")
|
||||
owner.Columns["id_owner"] = &models.Column{
|
||||
Name: "id_owner",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
schema.Tables = append(schema.Tables, owner)
|
||||
|
||||
// API Provider table
|
||||
apiProvider := models.InitTable("api_provider", "org")
|
||||
apiProvider.Columns["id_api_provider"] = &models.Column{
|
||||
Name: "id_api_provider",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
apiProvider.Columns["rid_owner"] = &models.Column{
|
||||
Name: "rid_owner",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
apiProvider.Constraints["fk_owner"] = &models.Constraint{
|
||||
Name: "fk_owner",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_owner"},
|
||||
ReferencedTable: "owner",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_owner"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, apiProvider)
|
||||
|
||||
// Login table
|
||||
login := models.InitTable("login", "org")
|
||||
login.Columns["id_login"] = &models.Column{
|
||||
Name: "id_login",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
login.Columns["rid_api_provider"] = &models.Column{
|
||||
Name: "rid_api_provider",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
login.Constraints["fk_api_provider"] = &models.Constraint{
|
||||
Name: "fk_api_provider",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_api_provider"},
|
||||
ReferencedTable: "api_provider",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_api_provider"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, login)
|
||||
|
||||
// Filepointer table
|
||||
filepointer := models.InitTable("filepointer", "org")
|
||||
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||
Name: "id_filepointer",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
filepointer.Columns["rid_api_provider"] = &models.Column{
|
||||
Name: "rid_api_provider",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
|
||||
Name: "fk_api_provider",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_api_provider"},
|
||||
ReferencedTable: "api_provider",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_api_provider"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, filepointer)
|
||||
|
||||
// API Event table
|
||||
apiEvent := models.InitTable("api_event", "org")
|
||||
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||
Name: "id_api_event",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
apiEvent.Columns["rid_api_provider"] = &models.Column{
|
||||
Name: "rid_api_provider",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
|
||||
Name: "fk_api_provider",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_api_provider"},
|
||||
ReferencedTable: "api_provider",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_api_provider"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, apiEvent)
|
||||
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: tmpDir,
|
||||
Metadata: map[string]interface{}{
|
||||
"multi_file": true,
|
||||
},
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the api_provider file
|
||||
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read api_provider file: %v", err)
|
||||
}
|
||||
|
||||
contentStr := string(apiProviderContent)
|
||||
|
||||
// Verify all has-many relationships have unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Has one via rid_owner
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
if !strings.Contains(contentStr, exp) {
|
||||
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify NO duplicate field names
|
||||
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
|
||||
count := strings.Count(contentStr, "RelRIDAPIProvider")
|
||||
if count != 3 {
|
||||
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
|
||||
}
|
||||
|
||||
// Verify no duplicate declarations (would cause compilation error)
|
||||
duplicatePattern := "RelRIDAPIProviders []*Model"
|
||||
if strings.Contains(contentStr, duplicatePattern) {
|
||||
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||
// Test scenario: table with columns that would conflict with generated method names
|
||||
table := models.InitTable("audit_table", "audit")
|
||||
table.Columns["id_audit_table"] = &models.Column{
|
||||
Name: "id_audit_table",
|
||||
Type: "smallint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
Sequence: 1,
|
||||
}
|
||||
table.Columns["table_name"] = &models.Column{
|
||||
Name: "table_name",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 2,
|
||||
}
|
||||
table.Columns["table_schema"] = &models.Column{
|
||||
Name: "table_schema",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 3,
|
||||
}
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the generated file
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||
if !strings.Contains(generated, "TableName_") {
|
||||
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the struct tag still references the correct database column
|
||||
if !strings.Contains(generated, `bun:"table_name,`) {
|
||||
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the TableName() method still exists and doesn't conflict
|
||||
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify NO field named just "TableName" (without underscore)
|
||||
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
|
||||
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||
|
||||
@@ -126,7 +126,15 @@ func (w *Writer) tableToDBML(t *models.Table) string {
|
||||
attrs = append(attrs, "increment")
|
||||
}
|
||||
if column.Default != nil {
|
||||
attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default))
|
||||
// Check if default value contains backticks (DBML expressions like `now()`)
|
||||
defaultStr := fmt.Sprintf("%v", column.Default)
|
||||
if strings.HasPrefix(defaultStr, "`") && strings.HasSuffix(defaultStr, "`") {
|
||||
// Already an expression with backticks, use as-is
|
||||
attrs = append(attrs, fmt.Sprintf("default: %s", defaultStr))
|
||||
} else {
|
||||
// Regular value, wrap in single quotes
|
||||
attrs = append(attrs, fmt.Sprintf("default: '%v'", column.Default))
|
||||
}
|
||||
}
|
||||
|
||||
if len(attrs) > 0 {
|
||||
|
||||
@@ -133,7 +133,11 @@ func (w *Writer) mapTableFields(table *models.Table) models.DCTXTable {
|
||||
prefix = table.Name[:3]
|
||||
}
|
||||
|
||||
tableGuid := w.newGUID()
|
||||
// Use GUID from model if available, otherwise generate a new one
|
||||
tableGuid := table.GUID
|
||||
if tableGuid == "" {
|
||||
tableGuid = w.newGUID()
|
||||
}
|
||||
w.tableGuidMap[table.Name] = tableGuid
|
||||
|
||||
dctxTable := models.DCTXTable{
|
||||
@@ -171,7 +175,11 @@ func (w *Writer) mapTableKeys(table *models.Table) []models.DCTXKey {
|
||||
}
|
||||
|
||||
func (w *Writer) mapField(column *models.Column) models.DCTXField {
|
||||
guid := w.newGUID()
|
||||
// Use GUID from model if available, otherwise generate a new one
|
||||
guid := column.GUID
|
||||
if guid == "" {
|
||||
guid = w.newGUID()
|
||||
}
|
||||
fieldKey := fmt.Sprintf("%s.%s", column.Table, column.Name)
|
||||
w.fieldGuidMap[fieldKey] = guid
|
||||
|
||||
@@ -209,7 +217,11 @@ func (w *Writer) mapDataType(dataType string) string {
|
||||
}
|
||||
|
||||
func (w *Writer) mapKey(index *models.Index, table *models.Table) models.DCTXKey {
|
||||
guid := w.newGUID()
|
||||
// Use GUID from model if available, otherwise generate a new one
|
||||
guid := index.GUID
|
||||
if guid == "" {
|
||||
guid = w.newGUID()
|
||||
}
|
||||
keyKey := fmt.Sprintf("%s.%s", table.Name, index.Name)
|
||||
w.keyGuidMap[keyKey] = guid
|
||||
|
||||
@@ -344,7 +356,7 @@ func (w *Writer) mapRelation(rel *models.Relationship, schema *models.Schema) mo
|
||||
}
|
||||
|
||||
return models.DCTXRelation{
|
||||
Guid: w.newGUID(),
|
||||
Guid: rel.GUID, // Use GUID from relationship model
|
||||
PrimaryTable: w.tableGuidMap[rel.ToTable], // GUID of the 'to' table (e.g., users)
|
||||
ForeignTable: w.tableGuidMap[rel.FromTable], // GUID of the 'from' table (e.g., posts)
|
||||
PrimaryKey: primaryKeyGUID,
|
||||
|
||||
@@ -127,6 +127,51 @@ func (w *Writer) databaseToDrawDB(d *models.Database) *DrawDBSchema {
|
||||
}
|
||||
}
|
||||
|
||||
// Create subject areas for domains
|
||||
for domainIdx, domainModel := range d.Domains {
|
||||
// Calculate bounds for all tables in this domain
|
||||
minX, minY := 999999, 999999
|
||||
maxX, maxY := 0, 0
|
||||
|
||||
domainTableCount := 0
|
||||
for _, domainTable := range domainModel.Tables {
|
||||
// Find the table in the schema to get its position
|
||||
for _, t := range schema.Tables {
|
||||
if t.Name == domainTable.TableName {
|
||||
if t.X < minX {
|
||||
minX = t.X
|
||||
}
|
||||
if t.Y < minY {
|
||||
minY = t.Y
|
||||
}
|
||||
if t.X+colWidth > maxX {
|
||||
maxX = t.X + colWidth
|
||||
}
|
||||
if t.Y+rowHeight > maxY {
|
||||
maxY = t.Y + rowHeight
|
||||
}
|
||||
domainTableCount++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only create area if domain has tables in this schema
|
||||
if domainTableCount > 0 {
|
||||
area := &DrawDBArea{
|
||||
ID: areaID,
|
||||
Name: domainModel.Name,
|
||||
Color: getColorForIndex(len(d.Schemas) + domainIdx), // Use different colors than schemas
|
||||
X: minX - 20,
|
||||
Y: minY - 20,
|
||||
Width: maxX - minX + 40,
|
||||
Height: maxY - minY + 40,
|
||||
}
|
||||
schema.SubjectAreas = append(schema.SubjectAreas, area)
|
||||
areaID++
|
||||
}
|
||||
}
|
||||
|
||||
// Add relationships
|
||||
for _, schemaModel := range d.Schemas {
|
||||
for _, table := range schemaModel.Tables {
|
||||
|
||||
@@ -196,7 +196,9 @@ func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *
|
||||
}
|
||||
|
||||
// Generate filename: {tableName}.ts
|
||||
filename := filepath.Join(w.options.OutputPath, table.Name+".ts")
|
||||
// Sanitize table name to remove quotes, comments, and invalid characters
|
||||
safeTableName := writers.SanitizeFilename(table.Name)
|
||||
filename := filepath.Join(w.options.OutputPath, safeTableName+".ts")
|
||||
return os.WriteFile(filename, []byte(code), 0644)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"sort"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// TemplateData represents the data passed to the template for code generation
|
||||
@@ -24,6 +25,7 @@ type ModelData struct {
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
PrimaryKeyType string // Go type of the primary key field
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
}
|
||||
@@ -103,19 +105,20 @@ func (td *TemplateData) FinalizeImports() {
|
||||
}
|
||||
|
||||
// NewModelData creates a new ModelData from a models.Table
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
||||
tableName := table.Name
|
||||
if schema != "" {
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
|
||||
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
|
||||
|
||||
// Generate model name: singularize and convert to PascalCase
|
||||
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||
singularTable := Singularize(table.Name)
|
||||
modelName := SnakeCaseToPascalCase(singularTable)
|
||||
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||
|
||||
// Add "Model" prefix if not already present
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schema != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schema)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
model := &ModelData{
|
||||
@@ -131,8 +134,11 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
// Find primary key
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
||||
model.IDColumnName = col.Name
|
||||
// Sanitize column name to remove backticks
|
||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
model.IDColumnName = safeName
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -141,6 +147,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
columns := sortColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
field := columnToField(col, table, typeMapper)
|
||||
// Check for name collision with generated methods and rename if needed
|
||||
field.Name = resolveFieldNameCollision(field.Name)
|
||||
model.Fields = append(model.Fields, field)
|
||||
}
|
||||
|
||||
@@ -149,10 +157,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
||||
|
||||
// columnToField converts a models.Column to FieldData
|
||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
||||
// Sanitize column name first to remove backticks before generating field name
|
||||
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||
fieldName := SnakeCaseToPascalCase(safeName)
|
||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
gormTag := typeMapper.BuildGormTag(col, table)
|
||||
jsonTag := col.Name // Use column name for JSON tag
|
||||
// Use same sanitized name for JSON tag
|
||||
jsonTag := safeName
|
||||
|
||||
return &FieldData{
|
||||
Name: fieldName,
|
||||
@@ -179,9 +190,28 @@ func formatComment(description, comment string) string {
|
||||
return comment
|
||||
}
|
||||
|
||||
// hasModelPrefix checks if a name already has "Model" prefix
|
||||
func hasModelPrefix(name string) bool {
|
||||
return len(name) >= 5 && name[:5] == "Model"
|
||||
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||
// and adds an underscore suffix if there's a collision
|
||||
func resolveFieldNameCollision(fieldName string) string {
|
||||
// List of method names that are generated by the template
|
||||
reservedNames := map[string]bool{
|
||||
"TableName": true,
|
||||
"TableNameOnly": true,
|
||||
"SchemaName": true,
|
||||
"GetID": true,
|
||||
"GetIDStr": true,
|
||||
"SetID": true,
|
||||
"UpdateID": true,
|
||||
"GetIDName": true,
|
||||
"GetPrefix": true,
|
||||
}
|
||||
|
||||
// Check if field name conflicts with a reserved method name
|
||||
if reservedNames[fieldName] {
|
||||
return fieldName + "_"
|
||||
}
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// sortColumns sorts columns by sequence, then by name
|
||||
|
||||
@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
|
||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||
// UpdateID updates the primary key value
|
||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||
m.{{.PrimaryKeyField}} = int32(newid)
|
||||
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// TypeMapper handles type conversions between SQL and Go types
|
||||
@@ -199,12 +200,15 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
||||
var parts []string
|
||||
|
||||
// Always include column name (lowercase as per user requirement)
|
||||
parts = append(parts, fmt.Sprintf("column:%s", column.Name))
|
||||
// Sanitize to remove backticks which would break struct tag syntax
|
||||
safeName := writers.SanitizeStructTagValue(column.Name)
|
||||
parts = append(parts, fmt.Sprintf("column:%s", safeName))
|
||||
|
||||
// Add type if specified
|
||||
if column.Type != "" {
|
||||
// Include length, precision, scale if present
|
||||
typeStr := column.Type
|
||||
// Sanitize type to remove backticks
|
||||
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||
if column.Length > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||
} else if column.Precision > 0 {
|
||||
@@ -234,7 +238,9 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
||||
|
||||
// Default value
|
||||
if column.Default != nil {
|
||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
||||
// Sanitize default value to remove backticks
|
||||
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
||||
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||
}
|
||||
|
||||
// Check for unique constraint
|
||||
@@ -331,5 +337,5 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
||||
|
||||
// GetSQLTypesImport returns the import path for sql_types
|
||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||
return "github.com/bitechdev/ResolveSpec/pkg/common/sql_types"
|
||||
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -82,7 +83,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
// Collect all models
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(formatted)
|
||||
if err := w.writeOutput(formatted); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run go fmt on the output file
|
||||
if w.options.OutputPath != "" {
|
||||
w.runGoFmt(w.options.OutputPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
@@ -165,7 +175,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Create model data
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -201,13 +211,19 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
}
|
||||
|
||||
// Generate filename: sql_{schema}_{table}.go
|
||||
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
||||
// Sanitize schema and table names to remove quotes, comments, and invalid characters
|
||||
safeSchemaName := writers.SanitizeFilename(schema.Name)
|
||||
safeTableName := writers.SanitizeFilename(table.Name)
|
||||
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
|
||||
filepath := filepath.Join(w.options.OutputPath, filename)
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Run go fmt on the generated file
|
||||
w.runGoFmt(filepath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +232,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
|
||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||
// Track used field names to detect duplicates
|
||||
usedFieldNames := make(map[string]int)
|
||||
|
||||
// For each foreign key in this table, add a belongs-to relationship
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
@@ -229,8 +248,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
}
|
||||
|
||||
// Create relationship field (belongs-to)
|
||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
||||
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||
fieldName := w.generateBelongsToFieldName(constraint)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
@@ -257,8 +277,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
||||
// Check if this constraint references our table
|
||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||
// Add has-many relationship
|
||||
otherModelName := w.getModelName(otherTable.Name)
|
||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
||||
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
@@ -289,22 +310,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
||||
return nil
|
||||
}
|
||||
|
||||
// getModelName generates the model name from a table name
|
||||
func (w *Writer) getModelName(tableName string) string {
|
||||
// getModelName generates the model name from schema and table name
|
||||
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
modelName := SnakeCaseToPascalCase(singular)
|
||||
tablePart := SnakeCaseToPascalCase(singular)
|
||||
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
if schemaName != "" {
|
||||
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||
modelName = "Model" + schemaPart + tablePart
|
||||
} else {
|
||||
modelName = "Model" + tablePart
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
// generateRelationshipFieldName generates a field name for a relationship
|
||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
||||
// Use just the prefix (3 letters) for relationship fields
|
||||
return GeneratePrefix(tableName)
|
||||
// generateBelongsToFieldName generates a field name for belongs-to relationships
|
||||
// Uses the foreign key column name for uniqueness
|
||||
func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) string {
|
||||
// Use the foreign key column name to ensure uniqueness
|
||||
// If there are multiple columns, use the first one
|
||||
if len(constraint.Columns) > 0 {
|
||||
columnName := constraint.Columns[0]
|
||||
// Convert to PascalCase for proper Go field naming
|
||||
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
||||
return "Rel" + SnakeCaseToPascalCase(columnName)
|
||||
}
|
||||
|
||||
// Fallback to table-based prefix if no columns defined
|
||||
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
||||
}
|
||||
|
||||
// generateHasManyFieldName generates a field name for has-many relationships
|
||||
// Uses the foreign key column name + source table name to avoid duplicates
|
||||
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||
// For has-many, we need to include the source table name to avoid duplicates
|
||||
// e.g., multiple tables referencing the same column on this table
|
||||
if len(constraint.Columns) > 0 {
|
||||
columnName := constraint.Columns[0]
|
||||
// Get the model name for the source table (pluralized)
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
// Remove "Model" prefix if present
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
|
||||
// Convert column to PascalCase and combine with source table
|
||||
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
||||
columnPart := SnakeCaseToPascalCase(columnName)
|
||||
return "Rel" + columnPart + Pluralize(sourceModelName)
|
||||
}
|
||||
|
||||
// Fallback to table-based naming
|
||||
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||
return "Rel" + Pluralize(sourceModelName)
|
||||
}
|
||||
|
||||
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
||||
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
||||
originalName := fieldName
|
||||
count := usedNames[originalName]
|
||||
|
||||
if count > 0 {
|
||||
// Name is already used, add numeric suffix
|
||||
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
||||
}
|
||||
|
||||
// Increment the counter for this base name
|
||||
usedNames[originalName]++
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// getPackageName returns the package name from options or defaults to "models"
|
||||
@@ -335,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runGoFmt runs go fmt on the specified file
|
||||
func (w *Writer) runGoFmt(filepath string) {
|
||||
cmd := exec.Command("gofmt", "-w", filepath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Don't fail the whole operation if gofmt fails, just warn
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||
func (w *Writer) shouldUseMultiFile() bool {
|
||||
// Check if multi_file is explicitly set in metadata
|
||||
@@ -380,6 +465,7 @@ func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
|
||||
DatabaseVersion: db.DatabaseVersion,
|
||||
SourceFormat: db.SourceFormat,
|
||||
Schemas: nil, // Don't include schemas to avoid circular reference
|
||||
GUID: db.GUID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,5 +482,6 @@ func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *mo
|
||||
Sequence: schema.Sequence,
|
||||
RefDatabase: w.createDatabaseRef(db), // Include database ref
|
||||
Tables: nil, // Don't include tables to avoid circular reference
|
||||
GUID: schema.GUID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelUser struct",
|
||||
"type ModelPublicUser struct",
|
||||
"ID",
|
||||
"int64",
|
||||
"Email",
|
||||
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
"time.Time",
|
||||
"gorm:\"column:id",
|
||||
"gorm:\"column:email",
|
||||
"func (m ModelUser) TableName() string",
|
||||
"func (m ModelPublicUser) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelUser) GetID() int64",
|
||||
"func (m ModelPublicUser) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
@@ -164,9 +164,437 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
||||
t.Fatalf("Failed to read posts file: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(postsContent), "USE *ModelUser") {
|
||||
// Relationship field should be present
|
||||
t.Logf("Posts content:\n%s", string(postsContent))
|
||||
postsStr := string(postsContent)
|
||||
|
||||
// Verify relationship is present with new naming convention
|
||||
// Should now be RelUserID (belongs-to) instead of USE
|
||||
if !strings.Contains(postsStr, "RelUserID") {
|
||||
t.Errorf("Missing relationship field RelUserID (new naming convention)")
|
||||
}
|
||||
|
||||
// Check users file contains has-many relationship
|
||||
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read users file: %v", err)
|
||||
}
|
||||
|
||||
usersStr := string(usersContent)
|
||||
|
||||
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||
// Test scenario: api_event table with multiple foreign keys to filepointer table
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("org")
|
||||
|
||||
// Filepointer table
|
||||
filepointer := models.InitTable("filepointer", "org")
|
||||
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||
Name: "id_filepointer",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
schema.Tables = append(schema.Tables, filepointer)
|
||||
|
||||
// API event table with two foreign keys to filepointer
|
||||
apiEvent := models.InitTable("api_event", "org")
|
||||
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||
Name: "id_api_event",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
|
||||
Name: "rid_filepointer_request",
|
||||
Type: "bigint",
|
||||
NotNull: false,
|
||||
}
|
||||
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
|
||||
Name: "rid_filepointer_response",
|
||||
Type: "bigint",
|
||||
NotNull: false,
|
||||
}
|
||||
|
||||
// Add constraints
|
||||
apiEvent.Constraints["fk_request"] = &models.Constraint{
|
||||
Name: "fk_request",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_filepointer_request"},
|
||||
ReferencedTable: "filepointer",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_filepointer"},
|
||||
}
|
||||
apiEvent.Constraints["fk_response"] = &models.Constraint{
|
||||
Name: "fk_response",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_filepointer_response"},
|
||||
ReferencedTable: "filepointer",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_filepointer"},
|
||||
}
|
||||
|
||||
schema.Tables = append(schema.Tables, apiEvent)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: tmpDir,
|
||||
Metadata: map[string]interface{}{
|
||||
"multi_file": true,
|
||||
},
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the api_event file
|
||||
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read api_event file: %v", err)
|
||||
}
|
||||
|
||||
contentStr := string(apiEventContent)
|
||||
|
||||
// Verify both relationships have unique names based on column names
|
||||
expectations := []struct {
|
||||
fieldName string
|
||||
tag string
|
||||
}{
|
||||
{"RelRIDFilepointerRequest", "foreignKey:RIDFilepointerRequest"},
|
||||
{"RelRIDFilepointerResponse", "foreignKey:RIDFilepointerResponse"},
|
||||
}
|
||||
|
||||
for _, exp := range expectations {
|
||||
if !strings.Contains(contentStr, exp.fieldName) {
|
||||
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
|
||||
}
|
||||
if !strings.Contains(contentStr, exp.tag) {
|
||||
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
|
||||
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
|
||||
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
|
||||
}
|
||||
|
||||
// Also verify has-many relationships on filepointer table
|
||||
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read filepointer file: %v", err)
|
||||
}
|
||||
|
||||
filepointerStr := string(filepointerContent)
|
||||
|
||||
// Should have two different has-many relationships with unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
if !strings.Contains(filepointerStr, exp) {
|
||||
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("org")
|
||||
|
||||
// Owner table
|
||||
owner := models.InitTable("owner", "org")
|
||||
owner.Columns["id_owner"] = &models.Column{
|
||||
Name: "id_owner",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
schema.Tables = append(schema.Tables, owner)
|
||||
|
||||
// API Provider table
|
||||
apiProvider := models.InitTable("api_provider", "org")
|
||||
apiProvider.Columns["id_api_provider"] = &models.Column{
|
||||
Name: "id_api_provider",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
apiProvider.Columns["rid_owner"] = &models.Column{
|
||||
Name: "rid_owner",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
apiProvider.Constraints["fk_owner"] = &models.Constraint{
|
||||
Name: "fk_owner",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_owner"},
|
||||
ReferencedTable: "owner",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_owner"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, apiProvider)
|
||||
|
||||
// Login table
|
||||
login := models.InitTable("login", "org")
|
||||
login.Columns["id_login"] = &models.Column{
|
||||
Name: "id_login",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
login.Columns["rid_api_provider"] = &models.Column{
|
||||
Name: "rid_api_provider",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
login.Constraints["fk_api_provider"] = &models.Constraint{
|
||||
Name: "fk_api_provider",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_api_provider"},
|
||||
ReferencedTable: "api_provider",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_api_provider"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, login)
|
||||
|
||||
// Filepointer table
|
||||
filepointer := models.InitTable("filepointer", "org")
|
||||
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||
Name: "id_filepointer",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
filepointer.Columns["rid_api_provider"] = &models.Column{
|
||||
Name: "rid_api_provider",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
|
||||
Name: "fk_api_provider",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_api_provider"},
|
||||
ReferencedTable: "api_provider",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_api_provider"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, filepointer)
|
||||
|
||||
// API Event table
|
||||
apiEvent := models.InitTable("api_event", "org")
|
||||
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||
Name: "id_api_event",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
apiEvent.Columns["rid_api_provider"] = &models.Column{
|
||||
Name: "rid_api_provider",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
|
||||
Name: "fk_api_provider",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_api_provider"},
|
||||
ReferencedTable: "api_provider",
|
||||
ReferencedSchema: "org",
|
||||
ReferencedColumns: []string{"id_api_provider"},
|
||||
}
|
||||
schema.Tables = append(schema.Tables, apiEvent)
|
||||
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: tmpDir,
|
||||
Metadata: map[string]interface{}{
|
||||
"multi_file": true,
|
||||
},
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the api_provider file
|
||||
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read api_provider file: %v", err)
|
||||
}
|
||||
|
||||
contentStr := string(apiProviderContent)
|
||||
|
||||
// Verify all has-many relationships have unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||
"RelRIDOwner", // Belongs to via rid_owner
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
if !strings.Contains(contentStr, exp) {
|
||||
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify NO duplicate field names
|
||||
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
|
||||
count := strings.Count(contentStr, "RelRIDAPIProvider")
|
||||
if count != 3 {
|
||||
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
|
||||
}
|
||||
|
||||
// Verify no duplicate declarations (would cause compilation error)
|
||||
duplicatePattern := "RelRIDAPIProviders []*Model"
|
||||
if strings.Contains(contentStr, duplicatePattern) {
|
||||
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||
// Test scenario: table with columns that would conflict with generated method names
|
||||
table := models.InitTable("audit_table", "audit")
|
||||
table.Columns["id_audit_table"] = &models.Column{
|
||||
Name: "id_audit_table",
|
||||
Type: "smallint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
Sequence: 1,
|
||||
}
|
||||
table.Columns["table_name"] = &models.Column{
|
||||
Name: "table_name",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 2,
|
||||
}
|
||||
table.Columns["table_schema"] = &models.Column{
|
||||
Name: "table_schema",
|
||||
Type: "varchar",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 3,
|
||||
}
|
||||
|
||||
// Create writer
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the generated file
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||
if !strings.Contains(generated, "TableName_") {
|
||||
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the struct tag still references the correct database column
|
||||
if !strings.Contains(generated, `gorm:"column:table_name;`) {
|
||||
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify the TableName() method still exists and doesn't conflict
|
||||
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||
}
|
||||
|
||||
// Verify NO field named just "TableName" (without underscore)
|
||||
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
|
||||
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||
// Test scenario: tables with different primary key types
|
||||
tests := []struct {
|
||||
name string
|
||||
pkType string
|
||||
expectedPK string
|
||||
castType string
|
||||
}{
|
||||
{"int32_pk", "int", "int32", "int32(newid)"},
|
||||
{"int16_pk", "smallint", "int16", "int16(newid)"},
|
||||
{"int64_pk", "bigint", "int64", "int64(newid)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
table := models.InitTable("test_table", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: tt.pkType,
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify UpdateID method has correct type cast
|
||||
if !strings.Contains(generated, tt.castType) {
|
||||
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
|
||||
}
|
||||
|
||||
// Verify no invalid int32(newid) for non-int32 types
|
||||
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
|
||||
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||
}
|
||||
|
||||
// Verify UpdateID parameter is int64 (for consistency)
|
||||
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
217
pkg/writers/pgsql/NAMING_CONVENTIONS.md
Normal file
217
pkg/writers/pgsql/NAMING_CONVENTIONS.md
Normal file
@@ -0,0 +1,217 @@
|
||||
# PostgreSQL Naming Conventions
|
||||
|
||||
Standardized naming rules for all database objects in RelSpec PostgreSQL output.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| Object Type | Prefix | Format | Example |
|
||||
| ----------------- | ----------- | ---------------------------------- | ------------------------ |
|
||||
| Primary Key | `pk_` | `pk_<schema>_<table>` | `pk_public_users` |
|
||||
| Foreign Key | `fk_` | `fk_<table>_<referenced_table>` | `fk_posts_users` |
|
||||
| Unique Constraint | `uk_` | `uk_<table>_<column>` | `uk_users_email` |
|
||||
| Unique Index | `uidx_` | `uidx_<table>_<column>` | `uidx_users_email` |
|
||||
| Regular Index | `idx_` | `idx_<table>_<column>` | `idx_posts_user_id` |
|
||||
| Check Constraint | `chk_` | `chk_<table>_<constraint_purpose>` | `chk_users_age_positive` |
|
||||
| Sequence | `identity_` | `identity_<table>_<column>` | `identity_users_id` |
|
||||
| Trigger | `t_` | `t_<purpose>_<table>` | `t_audit_users` |
|
||||
| Trigger Function | `tf_` | `tf_<purpose>_<table>` | `tf_audit_users` |
|
||||
|
||||
## Naming Rules by Object Type
|
||||
|
||||
### Primary Keys
|
||||
|
||||
**Pattern:** `pk_<schema>_<table>`
|
||||
|
||||
- Include schema name to avoid collisions across schemas
|
||||
- Use lowercase, snake_case format
|
||||
- Examples:
|
||||
- `pk_public_users`
|
||||
- `pk_audit_audit_log`
|
||||
- `pk_staging_temp_data`
|
||||
|
||||
### Foreign Keys
|
||||
|
||||
**Pattern:** `fk_<table>_<referenced_table>`
|
||||
|
||||
- Reference the table containing the FK followed by the referenced table
|
||||
- Use lowercase, snake_case format
|
||||
- Do NOT include column names in standard FK constraints
|
||||
- Examples:
|
||||
- `fk_posts_users` (posts.user_id → users.id)
|
||||
- `fk_comments_posts` (comments.post_id → posts.id)
|
||||
- `fk_order_items_orders` (order_items.order_id → orders.id)
|
||||
|
||||
### Unique Constraints
|
||||
|
||||
**Pattern:** `uk_<table>_<column>`
|
||||
|
||||
- Use `uk_` prefix strictly for database constraints (CONSTRAINT type)
|
||||
- Include column name for clarity
|
||||
- Examples:
|
||||
- `uk_users_email`
|
||||
- `uk_users_username`
|
||||
- `uk_products_sku`
|
||||
|
||||
### Unique Indexes
|
||||
|
||||
**Pattern:** `uidx_<table>_<column>`
|
||||
|
||||
- Use `uidx_` prefix strictly for index type objects
|
||||
- Distinguished from constraints for clarity and implementation flexibility
|
||||
- Examples:
|
||||
- `uidx_users_email`
|
||||
- `uidx_sessions_token`
|
||||
- `uidx_api_keys_key`
|
||||
|
||||
### Regular Indexes
|
||||
|
||||
**Pattern:** `idx_<table>_<column>`
|
||||
|
||||
- Standard indexes for query optimization
|
||||
- Single column: `idx_<table>_<column>`
|
||||
- Examples:
|
||||
- `idx_posts_user_id`
|
||||
- `idx_orders_created_at`
|
||||
- `idx_users_status`
|
||||
|
||||
### Check Constraints
|
||||
|
||||
**Pattern:** `chk_<table>_<constraint_purpose>`
|
||||
|
||||
- Describe the constraint validation purpose
|
||||
- Use lowercase, snake_case for the purpose
|
||||
- Examples:
|
||||
- `chk_users_age_positive` (CHECK (age > 0))
|
||||
- `chk_orders_quantity_positive` (CHECK (quantity > 0))
|
||||
- `chk_products_price_valid` (CHECK (price >= 0))
|
||||
- `chk_users_status_enum` (CHECK (status IN ('active', 'inactive')))
|
||||
|
||||
### Sequences
|
||||
|
||||
**Pattern:** `identity_<table>_<column>`
|
||||
|
||||
- Used for SERIAL/IDENTITY columns
|
||||
- Explicitly named for clarity and management
|
||||
- Examples:
|
||||
- `identity_users_id`
|
||||
- `identity_posts_id`
|
||||
- `identity_transactions_id`
|
||||
|
||||
### Triggers
|
||||
|
||||
**Pattern:** `t_<purpose>_<table>`
|
||||
|
||||
- Include purpose before table name
|
||||
- Lowercase, snake_case format
|
||||
- Examples:
|
||||
- `t_audit_users` (audit trigger on users table)
|
||||
- `t_update_timestamp_posts` (timestamp update trigger on posts)
|
||||
- `t_validate_orders` (validation trigger on orders)
|
||||
|
||||
### Trigger Functions
|
||||
|
||||
**Pattern:** `tf_<purpose>_<table>`
|
||||
|
||||
- Pair with trigger naming convention
|
||||
- Use `tf_` prefix to distinguish from triggers themselves
|
||||
- Examples:
|
||||
- `tf_audit_users` (function for t_audit_users)
|
||||
- `tf_update_timestamp_posts` (function for t_update_timestamp_posts)
|
||||
- `tf_validate_orders` (function for t_validate_orders)
|
||||
|
||||
## Multi-Column Objects
|
||||
|
||||
### Composite Primary Keys
|
||||
|
||||
**Pattern:** `pk_<schema>_<table>`
|
||||
|
||||
- Same as single-column PKs
|
||||
- Example: `pk_public_order_items` (composite key on order_id + item_id)
|
||||
|
||||
### Composite Unique Constraints
|
||||
|
||||
**Pattern:** `uk_<table>_<column1>_<column2>_[...]`
|
||||
|
||||
- Append all column names in order
|
||||
- Examples:
|
||||
- `uk_users_email_domain` (UNIQUE(email, domain))
|
||||
- `uk_inventory_warehouse_sku` (UNIQUE(warehouse_id, sku))
|
||||
|
||||
### Composite Unique Indexes
|
||||
|
||||
**Pattern:** `uidx_<table>_<column1>_<column2>_[...]`
|
||||
|
||||
- Append all column names in order
|
||||
- Examples:
|
||||
- `uidx_users_first_name_last_name` (UNIQUE INDEX on first_name, last_name)
|
||||
- `uidx_sessions_user_id_device_id` (UNIQUE INDEX on user_id, device_id)
|
||||
|
||||
### Composite Regular Indexes
|
||||
|
||||
**Pattern:** `idx_<table>_<column1>_<column2>_[...]`
|
||||
|
||||
- Append all column names in order
|
||||
- List columns in typical query filter order
|
||||
- Examples:
|
||||
- `idx_orders_user_id_created_at` (filter by user, then sort by created_at)
|
||||
- `idx_logs_level_timestamp` (filter by level, then by timestamp)
|
||||
|
||||
## Special Cases & Conventions
|
||||
|
||||
### Audit Trail Tables
|
||||
|
||||
- Audit table naming: `<original_table>_audit` or `audit_<original_table>`
|
||||
- Audit indexes follow standard pattern: `idx_<audit_table>_<column>`
|
||||
- Examples:
|
||||
- Users table audit: `users_audit` with `idx_users_audit_tablename`, `idx_users_audit_changedate`
|
||||
- Posts table audit: `posts_audit` with `idx_posts_audit_tablename`, `idx_posts_audit_changedate`
|
||||
|
||||
### Temporal/Versioning Tables
|
||||
|
||||
- Use suffix `_history` or `_versions` if needed
|
||||
- Apply standard naming rules with the full table name
|
||||
- Examples:
|
||||
- `idx_users_history_user_id`
|
||||
- `uk_posts_versions_version_number`
|
||||
|
||||
### Schema-Specific Objects
|
||||
|
||||
- Always qualify with schema when needed: `pk_<schema>_<table>`
|
||||
- Multiple schemas allowed: `pk_public_users`, `pk_staging_users`
|
||||
|
||||
### Reserved Words & Special Names
|
||||
|
||||
- Avoid PostgreSQL reserved keywords in object names
|
||||
- If column/table names conflict, use quoted identifiers in DDL
|
||||
- Naming convention rules still apply to the logical name
|
||||
|
||||
### Generated/Anonymous Indexes
|
||||
|
||||
- If an index lacks explicit naming, default to: `idx_<schema>_<table>`
|
||||
- Should be replaced with explicit names following standards
|
||||
- Examples (to be renamed):
|
||||
- `idx_public_users` → should be `idx_users_<column>`
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
### Code Generation
|
||||
|
||||
- Names are always lowercase in generated SQL
|
||||
- Underscore separators are required
|
||||
|
||||
### Migration Safety
|
||||
|
||||
- Do NOT rename objects after creation without explicit migration
|
||||
- Names should be consistent across all schema versions
|
||||
- Test generated DDL against PostgreSQL before deployment
|
||||
|
||||
### Testing
|
||||
|
||||
- Ensure consistency across all table and constraint generation
|
||||
- Test with reserved words to verify escaping
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- PostgreSQL Identifier Rules: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-IDENTIFIERS
|
||||
- Constraint Documentation: https://www.postgresql.org/docs/current/ddl-constraints.html
|
||||
- Index Documentation: https://www.postgresql.org/docs/current/indexes.html
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ type MigrationWriter struct {
|
||||
|
||||
// NewMigrationWriter creates a new templated migration writer
|
||||
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(options.FlattenSchema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
||||
}
|
||||
@@ -335,7 +336,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
SchemaName: schema.Name,
|
||||
TableName: modelTable.Name,
|
||||
ColumnName: modelCol.Name,
|
||||
ColumnType: modelCol.Type,
|
||||
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
||||
Default: defaultVal,
|
||||
NotNull: modelCol.NotNull,
|
||||
})
|
||||
@@ -359,7 +360,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
SchemaName: schema.Name,
|
||||
TableName: modelTable.Name,
|
||||
ColumnName: modelCol.Name,
|
||||
NewType: modelCol.Type,
|
||||
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -427,9 +428,11 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
||||
for _, modelTable := range model.Tables {
|
||||
currentTable := currentTables[strings.ToLower(modelTable.Name)]
|
||||
|
||||
// Process primary keys first
|
||||
// Process primary keys first - check explicit constraints
|
||||
foundExplicitPK := false
|
||||
for constraintName, constraint := range modelTable.Constraints {
|
||||
if constraint.Type == models.PrimaryKeyConstraint {
|
||||
foundExplicitPK = true
|
||||
shouldCreate := true
|
||||
|
||||
if currentTable != nil {
|
||||
@@ -464,6 +467,53 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
||||
}
|
||||
}
|
||||
|
||||
// If no explicit PK constraint, check for columns with IsPrimaryKey = true
|
||||
if !foundExplicitPK {
|
||||
pkColumns := []string{}
|
||||
for _, col := range modelTable.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
pkColumns = append(pkColumns, col.SQLName())
|
||||
}
|
||||
}
|
||||
if len(pkColumns) > 0 {
|
||||
sort.Strings(pkColumns)
|
||||
constraintName := fmt.Sprintf("pk_%s_%s", model.SQLName(), modelTable.SQLName())
|
||||
shouldCreate := true
|
||||
|
||||
if currentTable != nil {
|
||||
// Check if a PK constraint already exists (by any name)
|
||||
for _, constraint := range currentTable.Constraints {
|
||||
if constraint.Type == models.PrimaryKeyConstraint {
|
||||
shouldCreate = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreate {
|
||||
sql, err := w.executor.ExecuteCreatePrimaryKey(CreatePrimaryKeyData{
|
||||
SchemaName: model.Name,
|
||||
TableName: modelTable.Name,
|
||||
ConstraintName: constraintName,
|
||||
Columns: strings.Join(pkColumns, ", "),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
script := MigrationScript{
|
||||
ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName),
|
||||
ObjectType: "create primary key",
|
||||
Schema: model.Name,
|
||||
Priority: 160,
|
||||
Sequence: len(scripts),
|
||||
Body: sql,
|
||||
}
|
||||
scripts = append(scripts, script)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process indexes
|
||||
for indexName, modelIndex := range modelTable.Indexes {
|
||||
// Skip primary key indexes
|
||||
@@ -703,7 +753,7 @@ func (w *MigrationWriter) generateAuditScripts(schema *models.Schema, auditConfi
|
||||
}
|
||||
|
||||
// Generate audit function
|
||||
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
|
||||
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
|
||||
funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
|
||||
|
||||
funcSQL, err := w.executor.ExecuteAuditFunction(funcData)
|
||||
|
||||
@@ -121,7 +121,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify audit function
|
||||
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
|
||||
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
|
||||
t.Error("Migration missing audit function")
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTemplateExecutor_CreateTable(t *testing.T) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create executor: %v", err)
|
||||
}
|
||||
@@ -170,14 +170,14 @@ func TestTemplateExecutor_CreateTable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create executor: %v", err)
|
||||
}
|
||||
|
||||
data := AuditFunctionData{
|
||||
SchemaName: "public",
|
||||
FunctionName: "ft_audit_users",
|
||||
FunctionName: "tf_audit_users",
|
||||
TableName: "users",
|
||||
TablePrefix: "NULL",
|
||||
PrimaryKey: "id",
|
||||
@@ -202,7 +202,7 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||
|
||||
t.Logf("Generated SQL:\n%s", sql)
|
||||
|
||||
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
|
||||
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
|
||||
t.Error("SQL missing function definition")
|
||||
}
|
||||
if !strings.Contains(sql, "IF TG_OP = 'INSERT'") {
|
||||
@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||
t.Error("SQL missing DELETE handling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
||||
// Current database (empty)
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("entity")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
// Model database (with constraint starting with number)
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("entity")
|
||||
|
||||
// Create individual_actor_relationship table
|
||||
table := models.InitTable("individual_actor_relationship", "entity")
|
||||
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
|
||||
idCol.Type = "integer"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
|
||||
actorIDCol.Type = "integer"
|
||||
table.Columns["actor_id"] = actorIDCol
|
||||
|
||||
// Add constraint with name starting with number
|
||||
constraint := &models.Constraint{
|
||||
Name: "215162_fk_actor",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"actor_id"},
|
||||
ReferencedSchema: "entity",
|
||||
ReferencedTable: "actor",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
OnUpdate: "NO ACTION",
|
||||
}
|
||||
table.Constraints["215162_fk_actor"] = constraint
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
// Generate migration
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
err = writer.WriteMigration(model, current)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
t.Logf("Generated migration:\n%s", output)
|
||||
|
||||
// Verify constraint name is properly quoted
|
||||
if !strings.Contains(output, `"215162_fk_actor"`) {
|
||||
t.Error("Constraint name starting with number should be quoted")
|
||||
}
|
||||
|
||||
// Verify the SQL is syntactically correct (contains required keywords)
|
||||
if !strings.Contains(output, "ADD CONSTRAINT") {
|
||||
t.Error("Migration missing ADD CONSTRAINT")
|
||||
}
|
||||
if !strings.Contains(output, "FOREIGN KEY") {
|
||||
t.Error("Migration missing FOREIGN KEY")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
|
||||
"quote": quote,
|
||||
"escape": escape,
|
||||
"safe_identifier": safeIdentifier,
|
||||
"quote_ident": quoteIdent,
|
||||
|
||||
// Type conversion
|
||||
"goTypeToSQL": goTypeToSQL,
|
||||
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
|
||||
return strings.ToLower(safe)
|
||||
}
|
||||
|
||||
// quoteIdent quotes a PostgreSQL identifier if necessary
|
||||
// Identifiers need quoting if they:
|
||||
// - Start with a digit
|
||||
// - Contain special characters
|
||||
// - Are reserved keywords
|
||||
// - Contain uppercase letters (to preserve case)
|
||||
func quoteIdent(s string) string {
|
||||
if s == "" {
|
||||
return `""`
|
||||
}
|
||||
|
||||
// Check if quoting is needed
|
||||
needsQuoting := unicode.IsDigit(rune(s[0]))
|
||||
|
||||
// Starts with digit
|
||||
|
||||
// Contains uppercase letters or special characters
|
||||
for _, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
needsQuoting = true
|
||||
break
|
||||
}
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
|
||||
needsQuoting = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsQuoting {
|
||||
// Escape double quotes by doubling them
|
||||
escaped := strings.ReplaceAll(s, `"`, `""`)
|
||||
return `"` + escaped + `"`
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Type conversion functions
|
||||
|
||||
// goTypeToSQL converts Go type to PostgreSQL type
|
||||
|
||||
@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteIdent(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"valid_name", "valid_name"},
|
||||
{"ValidName", `"ValidName"`},
|
||||
{"123column", `"123column"`},
|
||||
{"215162_fk_constraint", `"215162_fk_constraint"`},
|
||||
{"user-id", `"user-id"`},
|
||||
{"user@domain", `"user@domain"`},
|
||||
{`"quoted"`, `"""quoted"""`},
|
||||
{"", `""`},
|
||||
{"lowercase", "lowercase"},
|
||||
{"with_underscore", "with_underscore"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := quoteIdent(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoTypeToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
|
||||
// Check that all expected functions are registered
|
||||
expectedFuncs := []string{
|
||||
"upper", "lower", "snake_case", "camelCase",
|
||||
"indent", "quote", "escape", "safe_identifier",
|
||||
"indent", "quote", "escape", "safe_identifier", "quote_ident",
|
||||
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
||||
"first", "last", "filter", "mapFunc", "join_with",
|
||||
"join",
|
||||
@@ -289,7 +314,7 @@ func TestFormatType(t *testing.T) {
|
||||
|
||||
// Test that template functions work in actual templates
|
||||
func TestTemplateFunctionsInTemplate(t *testing.T) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create executor: %v", err)
|
||||
}
|
||||
|
||||
@@ -18,14 +18,39 @@ type TemplateExecutor struct {
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// NewTemplateExecutor creates a new template executor
|
||||
func NewTemplateExecutor() (*TemplateExecutor, error) {
|
||||
// NewTemplateExecutor creates a new template executor.
|
||||
// flattenSchema controls whether schema.table identifiers use dot or underscore separation.
|
||||
func NewTemplateExecutor(flattenSchema bool) (*TemplateExecutor, error) {
|
||||
// Create template with custom functions
|
||||
funcMap := make(template.FuncMap)
|
||||
for k, v := range TemplateFunctions() {
|
||||
funcMap[k] = v
|
||||
}
|
||||
|
||||
// qual_table returns a quoted, schema-qualified identifier.
|
||||
// With flatten=false: "schema"."table" (or unquoted equivalents).
|
||||
// With flatten=true: "schema_table".
|
||||
funcMap["qual_table"] = func(schema, name string) string {
|
||||
if schema == "" {
|
||||
return quoteIdent(name)
|
||||
}
|
||||
if flattenSchema {
|
||||
return quoteIdent(schema + "_" + name)
|
||||
}
|
||||
return quoteIdent(schema) + "." + quoteIdent(name)
|
||||
}
|
||||
|
||||
// qual_table_raw is the same as qual_table but without identifier quoting.
|
||||
funcMap["qual_table_raw"] = func(schema, name string) string {
|
||||
if schema == "" {
|
||||
return name
|
||||
}
|
||||
if flattenSchema {
|
||||
return schema + "_" + name
|
||||
}
|
||||
return schema + "." + name
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS, "templates/*.tmpl")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse templates: %w", err)
|
||||
@@ -177,6 +202,72 @@ type AuditTriggerData struct {
|
||||
Events string
|
||||
}
|
||||
|
||||
// CreateUniqueConstraintData contains data for create unique constraint template
|
||||
type CreateUniqueConstraintData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
Columns string
|
||||
}
|
||||
|
||||
// CreateCheckConstraintData contains data for create check constraint template
|
||||
type CreateCheckConstraintData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
Expression string
|
||||
}
|
||||
|
||||
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
|
||||
type CreateForeignKeyWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
SourceColumns string
|
||||
TargetSchema string
|
||||
TargetTable string
|
||||
TargetColumns string
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
Deferrable bool
|
||||
}
|
||||
|
||||
// SetSequenceValueData contains data for set sequence value template
|
||||
type SetSequenceValueData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
SequenceName string
|
||||
ColumnName string
|
||||
}
|
||||
|
||||
// CreateSequenceData contains data for create sequence template
|
||||
type CreateSequenceData struct {
|
||||
SchemaName string
|
||||
SequenceName string
|
||||
Increment int
|
||||
MinValue int64
|
||||
MaxValue int64
|
||||
StartValue int64
|
||||
CacheSize int
|
||||
}
|
||||
|
||||
// AddColumnWithCheckData contains data for add column with existence check template
|
||||
type AddColumnWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
ColumnDefinition string
|
||||
}
|
||||
|
||||
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
|
||||
type CreatePrimaryKeyWithAutoGenCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||
Columns string
|
||||
}
|
||||
|
||||
// Execute methods for each template
|
||||
|
||||
// ExecuteCreateTable executes the create table template
|
||||
@@ -319,6 +410,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateUniqueConstraint executes the create unique constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateCheckConstraint executes the create check constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
|
||||
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteSetSequenceValue executes the set sequence value template
|
||||
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateSequence executes the create sequence template
|
||||
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteAddColumnWithCheck executes the add column with check template
|
||||
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
|
||||
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// Helper functions to build template data from models
|
||||
|
||||
// BuildCreateTableData builds CreateTableData from a models.Table
|
||||
@@ -355,7 +516,7 @@ func BuildAuditFunctionData(
|
||||
auditSchema string,
|
||||
userFunction string,
|
||||
) AuditFunctionData {
|
||||
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
|
||||
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
|
||||
|
||||
// Build list of audited columns
|
||||
auditedColumns := make([]*models.Column, 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
|
||||
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
||||
{{- if .NotNull}} NOT NULL{{end}};
|
||||
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND column_name = '{{.ColumnName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,7 +1,7 @@
|
||||
{{- if .SetDefault -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
{{- else -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
|
||||
{{- end -}}
|
||||
@@ -1,2 +1,2 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||
@@ -1,4 +1,4 @@
|
||||
CREATE OR REPLACE FUNCTION {{.SchemaName}}.{{.FunctionName}}()
|
||||
CREATE OR REPLACE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}()
|
||||
RETURNS trigger AS
|
||||
$body$
|
||||
DECLARE
|
||||
@@ -81,4 +81,4 @@ LANGUAGE plpgsql
|
||||
VOLATILE
|
||||
SECURITY DEFINER;
|
||||
|
||||
COMMENT ON FUNCTION {{.SchemaName}}.{{.FunctionName}}() IS 'Audit trigger function for table {{.SchemaName}}.{{.TableName}}';
|
||||
COMMENT ON FUNCTION {{qual_table_raw .SchemaName .FunctionName}}() IS 'Audit trigger function for table {{qual_table_raw .SchemaName .TableName}}';
|
||||
@@ -4,13 +4,13 @@ BEGIN
|
||||
SELECT 1
|
||||
FROM pg_trigger
|
||||
WHERE tgname = '{{.TriggerName}}'
|
||||
AND tgrelid = '{{.SchemaName}}.{{.TableName}}'::regclass
|
||||
AND tgrelid = '{{qual_table_raw .SchemaName .TableName}}'::regclass
|
||||
) THEN
|
||||
CREATE TRIGGER {{.TriggerName}}
|
||||
AFTER {{.Events}}
|
||||
ON {{.SchemaName}}.{{.TableName}}
|
||||
ON {{qual_table_raw .SchemaName .TableName}}
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION {{.SchemaName}}.{{.FunctionName}}();
|
||||
EXECUTE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}();
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,6 +1,6 @@
|
||||
{{/* Base constraint template */}}
|
||||
{{- define "constraint_base" -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}}
|
||||
{{block "constraint_definition" .}}{{end}};
|
||||
{{- end -}}
|
||||
@@ -15,7 +15,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||
DROP CONSTRAINT {{.ConstraintName}};
|
||||
END IF;
|
||||
END;
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
{{/* Base ALTER TABLE structure */}}
|
||||
{{- define "alter_table_base" -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||
{{block "alter_operation" .}}{{end}};
|
||||
{{- end -}}
|
||||
|
||||
@@ -30,5 +30,5 @@ $$;
|
||||
|
||||
{{/* Common drop pattern */}}
|
||||
{{- define "drop_if_exists" -}}
|
||||
{{block "drop_type" .}}{{end}} IF EXISTS {{.SchemaName}}.{{.ObjectName}};
|
||||
{{block "drop_type" .}}{{end}} IF EXISTS {{qual_table_raw .SchemaName .ObjectName}};
|
||||
{{- end -}}
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
|
||||
COMMENT ON COLUMN {{qual_table .SchemaName .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
|
||||
COMMENT ON TABLE {{qual_table .SchemaName .TableName}} IS '{{.Comment}}';
|
||||
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,10 +1,10 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
|
||||
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}
|
||||
DEFERRABLE;
|
||||
@@ -0,0 +1,18 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
|
||||
DEFERRABLE{{end}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,2 +1,2 @@
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
|
||||
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
|
||||
ON {{qual_table .SchemaName .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
@@ -6,8 +6,8 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -0,0 +1,27 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
auto_pk_name text;
|
||||
BEGIN
|
||||
-- Drop auto-generated primary key if it exists
|
||||
SELECT constraint_name INTO auto_pk_name
|
||||
FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_type = 'PRIMARY KEY'
|
||||
AND constraint_name IN ({{.AutoGenNames}});
|
||||
|
||||
IF auto_pk_name IS NOT NULL THEN
|
||||
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||
END IF;
|
||||
|
||||
-- Add named primary key if it doesn't exist
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
@@ -0,0 +1,6 @@
|
||||
CREATE SEQUENCE IF NOT EXISTS {{qual_table .SchemaName .SequenceName}}
|
||||
INCREMENT {{.Increment}}
|
||||
MINVALUE {{.MinValue}}
|
||||
MAXVALUE {{.MaxValue}}
|
||||
START {{.StartValue}}
|
||||
CACHE {{.CacheSize}};
|
||||
@@ -1,7 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
||||
CREATE TABLE IF NOT EXISTS {{qual_table .SchemaName .TableName}} (
|
||||
{{- range $i, $col := .Columns}}
|
||||
{{- if $i}},{{end}}
|
||||
{{$col.Name}} {{$col.Type}}
|
||||
{{quote_ident $col.Name}} {{$col.Type}}
|
||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||
{{- end}}
|
||||
|
||||
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1 +1 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
@@ -1 +1 @@
|
||||
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
|
||||
DROP INDEX IF EXISTS {{qual_table .SchemaName .IndexName}} CASCADE;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user